SplitModel: Pass to split transformers model (#1436)

## Describe your changes
Added two new passes that split a transformers model into multiple
models. The split is done at the transformers layer boundaries.
- `CaptureSplitInfo`: Uses the PyTorch/HF model to assign split ids to
the transformers layers and saves them as model attributes (see the
sketch after this list). The conversion pass or model builder pass adds
the split assignments to the model metadata.
- `SplitModel`: Uses the split_assignments metadata to break the model
into splits.
- If `include_all_nodes` is `True`: Nodes before the first split or
outside the splits are assigned to the first split. Nodes after the last
split are assigned to the last split.
- If `include_all_nodes` is `False`: Such nodes are ignored. This means
the embedding, attention mask computation, final norm, and LM head don't
appear in the splits. We expect the user to extract them separately.
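
For reference, a minimal sketch of the split assignments `CaptureSplitInfo` produces (mirroring the pass's use of `np.array_split`; the layer names below assume a Llama-style model with four layers and `num_splits=2`):

```python
import numpy as np

# children of the block being split; for a Llama-style model these are the
# transformers layers under "model.layers"
block_members = ["0", "1", "2", "3"]
num_splits = 2

split_assignments = {}
for split_idx, members in enumerate(np.array_split(block_members, num_splits)):
    for name in members:
        split_assignments[f"model.layers.{name}"] = split_idx

# -> {"model.layers.0": 0, "model.layers.1": 0, "model.layers.2": 1, "model.layers.3": 1}
print(split_assignments)
```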

`Cache.save_model` now supports saving composite models.
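
As a rough sketch of what this enables (keys follow the cache changes in this PR; component names and paths are illustrative, not taken verbatim from the diff), a saved composite model config now looks something like:

```python
# model_config.json written by Cache.save_model for a composite model; each
# component's resources are saved with the component name as a path prefix
composite_model_json = {
    "type": "CompositeModel",
    "config": {
        "model_component_names": ["split_0", "split_1"],
        "model_components": [
            # e.g. resources land under split_0_model/, split_1_model/
            {"type": "ONNXModel", "config": {"model_path": "split_0_model"}},
            {"type": "ONNXModel", "config": {"model_path": "split_1_model"}},
        ],
    },
}
```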
  
## Checklist before requesting a review
- [x] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [x] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
This commit is contained in:
Jambay Kinley 2024-10-25 16:19:09 -07:00 committed by GitHub
Parent 90695eba2f
Commit 1e42f0b2b2
18 changed files: 794 additions and 43 deletions


@ -144,6 +144,18 @@ ExtractAdapters
---------------
.. autoconfigclass:: olive.passes.ExtractAdapters
.. _capture_split_info:
CaptureSplitInfo
----------------
.. autoconfigclass:: olive.passes.CaptureSplitInfo
.. _split_model:
SplitModel
----------
.. autoconfigclass:: olive.passes.SplitModel
.. _lora:
LoRA


@ -63,5 +63,5 @@
},
"host": "local_system",
"target": "local_system",
"clean_cache": false
"output_dir": "models/llama2_qlora"
}


@ -0,0 +1,21 @@
{
"input_model": {
"type": "HfModel",
"load_kwargs": { "attn_implementation": "eager" },
"model_path": "meta-llama/Llama-2-7b-hf"
},
"systems": {
"local_system": {
"type": "LocalSystem",
"accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ]
}
},
"passes": {
"s": { "type": "CaptureSplitInfo", "num_splits": 3 },
"c": { "type": "OnnxConversion", "target_opset": 17, "torch_dtype": "float32" },
"sm": { "type": "SplitModel" }
},
"host": "local_system",
"target": "local_system",
"output_dir": "models/llama2_split"
}


@ -78,3 +78,11 @@ OrtSessionParamsTuning:
EP: null
precision: null
accelerator: null
CaptureSplitInfo:
EP: null
precision: null
accelerator: null
SplitModel:
EP: null
precision: null
accelerator: null


@ -376,9 +376,38 @@ class OliveCache:
output_dir.mkdir(parents=True, exist_ok=True)
model_json = self.load_model(model_id)
if model_json["type"].lower() == "compositemodel":
logger.warning("Saving models of type '%s' is not supported yet.", model_json["type"])
return None
model_json_config = model_json["config"]
copied_components = []
for component_names, component in zip(
model_json_config["model_component_names"], model_json_config["model_components"]
):
copied_components.append(
self._save_model(
component,
output_dir=output_dir,
overwrite=overwrite,
only_cache_files=only_cache_files,
path_prefix=component_names,
)
)
model_json_config["model_components"] = copied_components
model_json = self._save_additional_files(model_json, output_dir)
else:
model_json = self._save_model(model_json, output_dir, overwrite)
# save model json
with (output_dir / "model_config.json").open("w") as f:
json.dump(model_json, f, indent=4)
return model_json
def _save_model(
self,
model_json: dict,
output_dir: str,
overwrite: bool = False,
only_cache_files: bool = False,
path_prefix: str = None,
) -> dict:
# create model object so that we can get the resource paths
model_config: ModelConfig = ModelConfig.from_json(model_json)
resource_paths = model_config.get_resource_paths()
@ -413,17 +442,26 @@ class OliveCache:
continue
# save resource to output directory
model_json["config"][resource_name] = local_resource_path.save_to_dir(
output_dir, resource_name.replace("_path", ""), overwrite
)
path_name = resource_name.replace("_path", "")
if path_prefix:
path_name = f"{path_prefix}_{path_name}"
model_json["config"][resource_name] = local_resource_path.save_to_dir(output_dir, path_name, overwrite)
# we only have additional files for onnx models so saving to "model" is safe
model_path_name = "model"
if path_prefix:
model_path_name = f"{path_prefix}_{model_path_name}"
return self._save_additional_files(model_json, output_dir / model_path_name)
def _save_additional_files(self, model_json: dict, output_dir: Path) -> dict:
# Copy "additional files" to the model folder
# we only have additional files for onnx models so saving to "model" is safe
model_attributes = model_json["config"].get("model_attributes") or {}
additional_files = model_attributes.get("additional_files", [])
for i, src_filepath in enumerate(additional_files):
dst_filepath = output_dir / "model" / Path(src_filepath).name
output_dir.mkdir(parents=True, exist_ok=True)
dst_filepath = output_dir / Path(src_filepath).name
additional_files[i] = str(dst_filepath)
if not dst_filepath.exists():
@ -432,9 +470,6 @@ class OliveCache:
if additional_files:
model_json["config"]["model_attributes"]["additional_files"] = additional_files
# save model json
with (output_dir / "model_config.json").open("w") as f:
json.dump(model_json, f, indent=4)
return model_json
def disable_shared_cache(self):


@ -14,24 +14,42 @@ TASK_TO_PEFT_TASK_TYPE = {
# model_type -> name for layers
MODELS_TO_LAYERS_MAPPING = {
"bloom": "transformer.h",
"falcon": "transformer.h",
"gemma": "model.layers",
"gemma2": "model.layers",
"gpt2": "transformer.h",
"gpt_neox": "gpt_neox.layers",
"gptj": "transformer.h",
"llama": "model.layers",
"mistral": "model.layers",
"opt": "model.decoder.layers",
"phi": "model.layers",
"phi3": "model.layers",
"qwen": "transformer.h",
"qwen2": "model.layers",
}
# model_type -> name for embedding, these are the modules before the first layer
MODELS_TO_EMBEDDINGS_MAPPING = {
"bloom": ["transformer.word_embeddings", "transformer.word_embeddings_layernorm"],
"falcon": ["transformer.word_embeddings"],
"gemma": ["model.embed_tokens"],
"gemma2": ["model.embed_tokens"],
"gpt2": ["transformer.wte", "transformer.wpe"],
"gpt_neox": ["gpt_neox.embed_in"],
"gptj": ["transformer.wte"],
"llama": ["model.embed_tokens"],
"mistral": ["model.embed_tokens"],
"opt": [
"model.decoder.embed_tokens",
"model.decoder.embed_positions",
"model.model.decoder.project_out",
"model.model.decoder.project_in",
"model.decoder.project_out",
"model.decoder.project_in",
],
"phi": ["model.embed_tokens"],
"phi3": ["model.embed_tokens"],
"qwen": ["transformer.wte", "transformer.rotary_emb"],
"qwen2": ["model.embed_tokens"],
}
# model_type -> max length of the model, extracted from the config
@ -41,8 +59,14 @@ MODELS_TO_MAX_LENGTH_MAPPING = {
"bloom": 2048,
"gpt2": "n_positions",
"gpt_neox": "max_position_embeddings",
"gptj": "n_postions",
"llama": "max_position_embeddings",
"mistral": "max_position_embeddings",
"opt": "max_position_embeddings",
"phi": "max_position_embeddings",
"phi3": "max_position_embeddings",
"qwen": "seq_length",
"qwen2": "max_position_embeddings",
}
@ -85,6 +109,4 @@ MODEL_INSIDE_LAYER_MODULES = {
]
}
MODEL_LAYERS_BLOCK_NAME = {"phi3": "model.layers"}
MODELS_TO_LORA_TARGET_MODULES_MAPPING = {"phi3": ["o_proj", "qkv_proj"]}


@ -59,6 +59,7 @@
"VitisQOpQuantizer": { "module_path": "olive.passes.onnx.vitis_ai.quantizer.VitisQOpQuantizer" },
"quantize_static": { "module_path": "olive.passes.onnx.vitis_ai.quantize.quantize_static" },
"PowerOfTwoMethod": { "module_path": "olive.passes.onnx.vitis_ai.quant_utils.PowerOfTwoMethod" },
"SplitModel": { "module_path": "olive.passes.onnx.split.SplitModel" },
"OpenVINOConversion": {
"module_path": "olive.passes.openvino.conversion.OpenVINOConversion",
"extra_dependencies": [ "openvino" ]
@ -67,14 +68,15 @@
"module_path": "olive.passes.openvino.quantization.OpenVINOQuantization",
"extra_dependencies": [ "openvino" ]
},
"GptqQuantizer": {
"module_path": "olive.passes.pytorch.gptq.GptqQuantizer",
"module_dependencies": [ "auto-gptq", "optimum" ]
},
"AutoAWQQuantizer": {
"module_path": "olive.passes.pytorch.autoawq.AutoAWQQuantizer",
"module_dependencies": [ "autoawq" ]
},
"GptqQuantizer": {
"module_path": "olive.passes.pytorch.gptq.GptqQuantizer",
"module_dependencies": [ "auto-gptq", "optimum" ]
},
"CaptureSplitInfo": { "module_path": "olive.passes.pytorch.capture_split_info.CaptureSplitInfo" },
"MergeAdapterWeights": { "module_path": "olive.passes.pytorch.merge_adapter_weights.MergeAdapterWeights" },
"LoftQ": { "module_path": "olive.passes.pytorch.lora.LoftQ" },
"LoRA": { "module_path": "olive.passes.pytorch.lora.LoRA", "extra_dependencies": [ "lora" ] },


@ -246,7 +246,11 @@ class Pass(ABC):
# assumption: the model attributes from passes, if any, are more important than
# the input model attributes, we should not update/extend anymore outside of the pass run
output_model.model_attributes = output_model.model_attributes or model.model_attributes
Pass._carry_forward_additional_files(model, output_model)
if not isinstance(output_model, CompositeModelHandler):
# save and carry forward additional files into the output model path
# for composite model, the additional_files attribute is already present in the parent
# model_attributes
Pass._carry_forward_additional_files(model, output_model)
return output_model


@ -10,7 +10,6 @@ from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import onnx
import onnx.external_data_helper
import torch
from packaging import version
@ -376,6 +375,14 @@ class OnnxConversion(Pass):
pytorch_model, dummy_inputs, io_config, config, device, torch_dtype, tempfile.tempdir
)
model_attributes = deepcopy(model.model_attributes or {})
# add split information if present
split_assignments = model_attributes.get("split_assignments")
if split_assignments:
split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()])
onnx.helper.set_model_props(converted_onnx_model, {"split_assignments": split_assignment_str})
# save the model to the output path and return the model
output_model_path = resolve_onnx_path(output_model_path)
output_model = model_proto_to_olive_model(converted_onnx_model, output_model_path, config)
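
For reference, a small sketch of how the `split_assignments` string round-trips between the producing passes and `SplitModel` (format taken from the surrounding diffs):

```python
# serialize the split assignments (as the conversion / model builder passes do)
split_assignments = {"model.layers.0": 0, "model.layers.1": 1}
serialized = ";".join(f"{k}={v}" for k, v in split_assignments.items())
# -> "model.layers.0=0;model.layers.1=1"

# parse them back (as SplitModel does when reading the model's metadata_props)
parsed = {
    key: int(value)
    for key, value in (assignment.split("=") for assignment in serialized.split(";"))
}
assert parsed == split_assignments
```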


@ -246,6 +246,8 @@ class MatMulNBitsToQDQ(Pass):
dag.add_value_info(vi, graph_idx)
# remove the node
if is_model_output:
dag.remove_output(node_output)
dag.remove_node(node_name)
# rename to original name if it is a model output


@ -10,6 +10,7 @@ import logging
from pathlib import Path
from typing import Any, Dict, Union
import onnx
import transformers
from olive.common.utils import IntEnumBase, StrEnumBase
@ -177,6 +178,8 @@ class ModelBuilder(Pass):
if config[arg] is not None:
extra_args[arg] = "1" if config[arg] else "0"
model_attributes = copy.deepcopy(model.model_attributes or {})
try:
create_model(
model_name=model_path,
@ -189,6 +192,17 @@ class ModelBuilder(Pass):
cache_dir=transformers.utils.TRANSFORMERS_CACHE,
**extra_args,
)
# add split information if present
split_assignments = model_attributes.get("split_assignments")
if split_assignments:
split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()])
# load the model and set the split_assignments as model properties
# without the external data so that they can be used as is with the resaved model
model_proto = onnx.load(output_model_filepath, load_external_data=False)
onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str})
onnx.save(model_proto, output_model_filepath)
except Exception:
# if model building fails, clean up the intermediate files in the cache_dir
cache_dir = Path(transformers.utils.TRANSFORMERS_CACHE)
@ -209,7 +223,6 @@ class ModelBuilder(Pass):
json.dump(genai_config, ostrm, indent=4)
# add additional files generated by model builder to model_attributes
model_attributes = copy.deepcopy(model.model_attributes or {})
model_attributes["is_generative"] = True
additional_files = model_attributes.get("additional_files") or []
if metadata_only:


@ -81,9 +81,9 @@ class OnnxIO(ConfigBase):
class OnnxDAG:
"""ONNX model as a directed acyclic graph (DAG)."""
def __init__(self, model: "ModelProto"):
def __init__(self, model: "ModelProto", only_main_graph: bool = False):
self.model = model
self.graphs = self.get_all_graphs(self.model)
self.graphs = self.get_all_graphs(self.model, only_main_graph)
self.nodes: Dict[str, OnnxNode] = {}
self.ios: Dict[str, OnnxIO] = {}
self.connections = defaultdict(list)
@ -95,12 +95,16 @@ class OnnxDAG:
self._process_node(node, self.nodes, self.ios, self.connections, idx)
@staticmethod
def get_all_graphs(model: "ModelProto") -> List[GraphProto]:
def get_all_graphs(model: "ModelProto", only_main_graph: bool = False) -> List[GraphProto]:
"""Get all graphs in the model.
:param model: ONNX model.
:param only_main_graph: whether to return only the main graph.
:return: list of graphs in the model.
"""
if only_main_graph:
return [model.graph]
all_graphs = []
graph_queue = [model.graph]
while graph_queue:
@ -203,8 +207,7 @@ class OnnxDAG:
raise ValueError(f"Output {o} is already connected to another node.")
ios[o].source = name
for destination in ios[o].destination:
if destination != SpecialOutput.OUTPUT:
connections[name].append(destination)
connections[name].append(destination)
def add_input(self, input_proto: ValueInfoProto, graph_idx: int, keep_initializer: bool = False):
"""Add an input to the graph.
@ -224,7 +227,7 @@ class OnnxDAG:
"""
self._add_special_input(initializer, graph_idx, SpecialInput.INITIALIZER, keep_input)
def add_value_info(self, value_info: ValueInfoProto, graph_idx: int):
def add_value_info(self, value_info: ValueInfoProto, graph_idx: int, overwrite: bool = False):
"""Add a value info to the graph.
:param value_info: ValueInfoProto of the value info.
@ -236,8 +239,18 @@ class OnnxDAG:
self.ios[name] = OnnxIO(proto=[value_info], graph_idx=graph_idx)
return
assert not self.ios[name].proto, f"Value info for {name} already exists in the graph."
self.ios[name].proto.append(value_info)
assert (
overwrite or not self.ios[name].proto
), f"Value info for {name} already exists in the graph but overwrite is False."
self.ios[name].proto = [value_info]
def is_io(self, io_name: str) -> bool:
"""Check if an input/output exists in the graph.
:param io_name: name of the input/output.
:return: True if the input/output exists.
"""
return io_name in self.ios
def _add_special_input(
self,
@ -463,21 +476,37 @@ class OnnxDAG:
"""
return self.nodes[node_name].op_type
def get_node_inputs(self, node_name: str) -> List[str]:
def get_node_proto(self, node_name: str) -> NodeProto:
"""Get the node proto.
:param node_name: name of the node.
:return: NodeProto object.
"""
return self.nodes[node_name].proto
def get_node_inputs(self, node_name: str, skip_empty_io: bool = False) -> List[str]:
"""Get the input names of a node.
:param node_name: name of the node.
:param skip_empty_io: whether to skip empty inputs.
:return: list of input names.
"""
return list(self.nodes[node_name].inputs)
inputs = self.nodes[node_name].inputs
if skip_empty_io:
inputs = filter(lambda i: i != "", inputs)
return list(inputs)
def get_node_outputs(self, node_name: str) -> List[str]:
def get_node_outputs(self, node_name: str, skip_empty_io: bool = False) -> List[str]:
"""Get the output names of a node.
:param node_name: name of the node.
:param skip_empty_io: whether to skip empty outputs.
:return: list of output names.
"""
return list(self.nodes[node_name].outputs)
outputs = self.nodes[node_name].outputs
if skip_empty_io:
outputs = filter(lambda o: o != "", outputs)
return list(outputs)
def get_node_attributes(self, node_name: str) -> Dict[str, Any]:
"""Get the attributes of a node.
@ -511,6 +540,19 @@ class OnnxDAG:
"""
return SpecialInput.is_initializer(self.ios[io_name].source)
def is_constant_input(self, io_name: str, allow_input_initializer: bool = False) -> bool:
"""Check if an input/output comes from an initializer or a constant node.
:param io_name: name of the input/output.
:param allow_input_initializer: whether to consider input_initializer as a constant input.
:return: True if the input/output is a constant input.
"""
source = self.ios[io_name].source
if source in self.nodes:
return self.get_node_op_type(source) == "Constant"
return (source == SpecialInput.INITIALIZER) or (allow_input_initializer and SpecialInput.is_initializer(source))
def get_graph_idx(self, name: str) -> int:
"""Get the index of the graph containing the input/output or node."""
if name in self.ios:
@ -544,6 +586,17 @@ class OnnxDAG:
if self.is_output(io_name):
return
self.ios[io_name].destination.append(SpecialOutput.OUTPUT)
self.connections[self.get_producer(io_name)].append(SpecialOutput.OUTPUT)
def remove_output(self, io_name: str):
"""Remove an output from an input/output.
:param io_name: name of the input/output.
"""
if not self.is_output(io_name):
return
self.ios[io_name].destination.remove(SpecialOutput.OUTPUT)
self.connections[self.get_producer(io_name)].remove(SpecialOutput.OUTPUT)
def get_producer(self, io_name: str) -> str:
"""Get the producer of an input/output.
@ -553,16 +606,43 @@ class OnnxDAG:
"""
return self.ios[io_name].source
def get_consumers(self, node_name: str) -> List[str]:
def get_consumers(self, node_name: str, return_special_outputs: bool = False) -> List[str]:
"""Get the consumers of a node.
:param node_name: name of the node. It can also be an input or initializer.
:return: list of names of nodes that consume one/more outputs of the node.
"""
if node_name in self.ios and SpecialInput.is_special_input(self.ios[node_name].source):
return list(self.ios[node_name].destination)
if node_name in self.ios:
consumers = self.ios[node_name].destination
else:
consumers = self.connections[node_name]
return list(self.connections[node_name])
if return_special_outputs:
return list(consumers)
return list(filter(lambda c: c != SpecialOutput.OUTPUT, consumers))
def get_parents(self, node_name: str, return_special_inputs: bool = False) -> List[str]:
"""Get the parents of a node.
:param node_name: name of the node.
:return: list of names of nodes that produce one/more inputs of the node.
"""
parents = [self.ios[i].source for i in self.nodes[node_name].inputs]
if return_special_inputs:
return parents
return list(filter(lambda p: not SpecialInput.is_special_input(p), parents))
def is_input_consumer(self, node_name: str) -> bool:
"""Check if a node is an input consumer.
:param node_name: name of the node.
:return: True if the node consumes one/more inputs that are also model inputs.
"""
return any(SpecialInput.is_special_input(p) for p in self.get_parents(node_name, return_special_inputs=True))
def is_output_producer(self, node_name: str) -> bool:
"""Check if a node is an output producer.
@ -570,7 +650,7 @@ class OnnxDAG:
:param node_name: name of the node.
:return: True if the node produces one/more outputs that are also model outputs.
"""
return any(SpecialOutput.OUTPUT in self.ios[o].destination for o in self.nodes[node_name].outputs)
return SpecialOutput.OUTPUT in self.get_consumers(node_name, return_special_outputs=True)
def _topological_sort_util(self, v: str, visited: Set[str], order: List[str]):
"""Do depth-first search starting from node v.
@ -683,10 +763,11 @@ class OnnxDAG:
logger.debug("Removed %d Identity nodes", len(nodes_to_remove))
@classmethod
def from_model_path(cls, model_path: Union[str, Path]) -> "OnnxDAG":
def from_model_path(cls, model_path: Union[str, Path], only_main_graph: bool = False) -> "OnnxDAG":
"""Load an ONNX model and create an self.
:param model_path: path to the ONNX model.
:param only_main_graph: whether to create a DAG with only the main graph.
:return: OnnxDAG object.
"""
return cls(onnx.load(model_path))
return cls(onnx.load(model_path), only_main_graph=only_main_graph)
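
A brief usage sketch of the new `OnnxDAG` options introduced above (the model path is illustrative):

```python
from olive.passes.onnx.onnx_dag import OnnxDAG

# build a DAG over the main graph only; nested subgraphs are not traversed
dag = OnnxDAG.from_model_path("model.onnx", only_main_graph=True)
dag.remove_identity_nodes()

for node_name in dag.topological_sort():
    # get_consumers excludes the special OUTPUT marker unless explicitly requested
    consumers = dag.get_consumers(node_name)
    if dag.is_output_producer(node_name):
        print(node_name, "produces a model output; other consumers:", consumers)
```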

olive/passes/onnx/split.py (new file, 280 lines)

@ -0,0 +1,280 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict
import numpy as np
import onnx
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import CompositeModelHandler, ONNXModelHandler
from olive.model.utils import resolve_onnx_path
from olive.passes import Pass
from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
from olive.passes.onnx.onnx_dag import OnnxDAG
from olive.passes.pass_config import PassConfigParam
logger = logging.getLogger(__name__)
class SplitModel(Pass):
@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
return {
"include_all_nodes": PassConfigParam(
type_=bool,
default_value=True,
description=(
"Include all nodes in the split model. Nodes outside the splits or before the first split will be"
" assigned to the first split. Nodes after the last split will be assigned to the last split. If"
" False, these nodes will not be included in the split models."
),
),
**get_external_data_config(),
}
def _run_for_config(
self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str
) -> CompositeModelHandler:
model_proto = model.load_model()
split_assignments = None
for metadata_prop in model_proto.metadata_props:
if metadata_prop.key == "split_assignments":
split_assignments = {
key: int(value)
for key, value in (assignment.split("=") for assignment in metadata_prop.value.split(";"))
}
break
# TODO(jambayk): Should we allow split assignments in the model attributes too?
if not split_assignments:
raise ValueError("No split assignments found in the model metadata")
# TODO(jambayk): Make this more generic, for now only assume transformers layers are split
# so depth of namespace is same for all split assignments
num_splits = len(np.unique(list(split_assignments.values())))
namespace_depth = len(next(iter(split_assignments)).split("."))
# create a dag for the model, won't split nested graphs
dag = OnnxDAG(model_proto, only_main_graph=True)
dag.remove_identity_nodes()
# empty dags for each split
split_proto = onnx.ModelProto(
ir_version=model_proto.ir_version,
opset_import=model_proto.opset_import,
producer_name="olive",
graph=onnx.GraphProto(name=model_proto.graph.name),
)
split_dags = [OnnxDAG(deepcopy(split_proto)) for _ in range(num_splits)]
# go through the nodes in topological order
node_order = dag.topological_sort()
node_assignments = {}
constant_nodes = set()
for node_name in node_order:
# will handle constant nodes later
if dag.get_node_op_type(node_name) == "Constant":
constant_nodes.add(node_name)
continue
name_components = node_name.replace("/", ".").lstrip(".").split(".")
namespace = ".".join(name_components[:namespace_depth])
if namespace in split_assignments:
node_assignments[node_name] = split_assignments[namespace]
# what is the next closest split, if not assigned to a split
next_split = deepcopy(node_assignments)
# already have a topological order, so we will go from the bottom up
for node_name in node_order[::-1]:
# constants cannot be children of a node
if node_name in node_assignments or node_name in constant_nodes:
continue
child_splits = [
next_split[child_name]
for child_name in dag.get_consumers(node_name)
if next_split[child_name] is not None
]
if child_splits:
next_split[node_name] = min(child_splits)
else:
next_split[node_name] = None
if config["include_all_nodes"]:
for node_name in node_order:
if node_name in node_assignments or node_name in constant_nodes:
continue
# a parent has a split: the node is after the last split, or a before/outside parent has already been assigned
parent_splits = [
node_assignments[parent_name]
for parent_name in dag.get_parents(node_name)
if parent_name in node_assignments
]
if parent_splits:
node_assignments[node_name] = max(parent_splits)
continue
# before the first split
if next_split[node_name] is not None:
node_assignments[node_name] = next_split[node_name]
continue
# outside the splits
node_assignments[node_name] = 0
else:
# handle unassigned nodes that are:
# - between splits: assign to the split of the parent node
# - between constant/initializer and splits: assign the next split
for node_name in node_order:
if node_name in node_assignments or node_name in constant_nodes:
continue
# after the last split
if next_split[node_name] is None:
continue
# between splits
parent_splits = [
node_assignments[parent_name]
for parent_name in dag.get_parents(node_name)
if parent_name in node_assignments
]
if parent_splits:
node_assignments[node_name] = max(parent_splits)
continue
# between constant/initializer and splits
if all(dag.is_constant_input(input_name) for input_name in dag.get_node_inputs(node_name)):
node_assignments[node_name] = next_split[node_name]
# handle cast nodes of the form:
# - Input -> Cast -> Split
# - Split -> Cast -> Output
for node_name in node_order:
if (
node_name in node_assignments
or dag.get_node_op_type(node_name) != "Cast"
# only one consumer (model output or another node)
or len(dag.get_consumers(node_name, True)) != 1
):
continue
if (
dag.is_input_consumer(node_name)
and (consumer := dag.get_consumers(node_name, True)[0]) in node_assignments
):
node_assignments[node_name] = node_assignments[consumer]
elif (
parent_name := dag.get_parents(node_name, True)[0]
) in node_assignments and dag.is_output_producer(node_name):
node_assignments[node_name] = node_assignments[parent_name]
# handle constant nodes, will add a copy of the constant to each split
for node_name in constant_nodes:
splits = set()
for consumer in dag.get_consumers(node_name):
if consumer in node_assignments:
splits.add(node_assignments[consumer])
if splits:
node_assignments[node_name] = list(splits)
# add the nodes to the split dags
# keep track of missing value info for inputs to the split dags
missing_vi = defaultdict(list)
for node_name in node_order:
split_id = node_assignments.get(node_name)
if split_id is None:
continue
if not isinstance(split_id, list):
# no need to worry about inputs for list split_id since it's only for constants
split_dag = split_dags[split_id]
# add the inputs to the nodes if not already present
for input_name in dag.get_node_inputs(node_name):
if not input_name:
# optional input left as ""
continue
# already added
if split_dag.is_io(input_name):
continue
io = dag.get_io(input_name)
# main graph inputs and/or initializers
if dag.is_input(input_name) or dag.is_initializer(input_name):
if dag.is_input(input_name):
split_dags[split_id].add_input(io.proto[0], 0, True)
if dag.is_initializer(input_name):
split_dags[split_id].add_initializer(io.proto[-1], 0, True)
continue
# cross split inputs
proto = io.proto[0] if io.proto else None
if not proto:
# missing value info
missing_vi[input_name].append(split_id)
proto = onnx.helper.make_empty_tensor_value_info(input_name)
split_dag.add_input(proto, 0)
# add the node to the split dag
split_dag.add_node(dag.get_node_proto(node_name), 0)
# process the node outputs
for output_name in dag.get_node_outputs(node_name):
if not output_name:
# optional output left as ""
continue
# mark as output if any consumer is not in the split
is_output = False
for consumer in dag.get_consumers(output_name, True):
if node_assignments.get(consumer) != split_id:
split_dag.make_output(output_name)
is_output = True
break
# add vi for the outputs
io = dag.get_io(output_name)
if io.proto:
split_dag.add_value_info(io.proto[0], 0)
elif is_output:
# missing value info
missing_vi[output_name].append(split_id)
else:
# add the constant to each split
for idx in split_id:
split_dags[idx].add_node(dag.get_node_proto(node_name), 0)
if missing_vi:
logger.debug("Missing value info for some io. Using onnxruntime shape inference to infer them.")
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
# should we just use the same model proto? might modify dynamic shapes of existing value infos
# if this becomes an issue replace with a newly loaded model proto
shape_inferred_proto = SymbolicShapeInference.infer_shapes(model_proto, auto_merge=True)
shape_inferred_dag = OnnxDAG(shape_inferred_proto, only_main_graph=True)
for input_name, split_ids in missing_vi.items():
io = shape_inferred_dag.get_io(input_name)
if not io.proto:
raise ValueError(f"Missing value info for input {input_name} for split {split_id}")
for idx in split_ids:
split_dags[idx].add_value_info(io.proto[0], 0, overwrite=True)
component_models = []
component_names = []
for i, split_dag in enumerate(split_dags):
split_name = f"split_{i}"
split_dir = Path(output_model_path).with_suffix("") / split_name
split_path = resolve_onnx_path(split_dir, f"{split_name}.onnx")
split_dag.update()
component_models.append(model_proto_to_olive_model(split_dag.model, split_path, config))
component_names.append(split_name)
return CompositeModelHandler(component_models, component_names)


@ -0,0 +1,76 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from typing import Any, Dict, Union
import numpy as np
from olive.common.hf.mappings import MODELS_TO_LAYERS_MAPPING
from olive.common.utils import get_attr
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import HfModelHandler, PyTorchModelHandler
from olive.passes import Pass
from olive.passes.pass_config import PassConfigParam
logger = logging.getLogger(__name__)
class CaptureSplitInfo(Pass):
"""Capture the split information of the model layers. Only splits the transformer layers."""
@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
return {
"num_splits": PassConfigParam(
type_=int,
required=True,
description="Number of splits to divide the model layers into.",
),
"block_to_split": PassConfigParam(
type_=str,
default_value=None,
description=(
"Name of the model block to split. Children of the block will be divided into the splits. For"
" supported transformers models, the default value is the transformers layer block name. Refer to"
" olive.common.hf.mappings.MODELS_TO_LAYERS_MAPPING for supported models."
),
),
}
def _run_for_config(
self, model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any], output_model_path: str
) -> Union[HfModelHandler, PyTorchModelHandler]:
block_to_split = config["block_to_split"]
# check for None specifically since "" is a valid value
if block_to_split is None and isinstance(model, HfModelHandler):
model_type = model.get_hf_model_type()
block_to_split = MODELS_TO_LAYERS_MAPPING.get(model_type, None)
if block_to_split is None:
raise ValueError("block_to_split is not set and could not be inferred. Please set it manually.")
block_members = []
# we could get the number of layers for hf model from the model attributes
# but will just load the model to make the logic simple for now
# consider loading with meta device to avoid loading the weights
loaded_model = model.load_model(cache_model=False)
block = get_attr(loaded_model, block_to_split)
if block is None:
raise ValueError(f"block_to_split {block_to_split} not found in model.")
for child_name, _ in block.named_children():
block_members.append(child_name)
split_assignments = {}
for split_idx, split_members in enumerate(np.array_split(block_members, config["num_splits"])):
for child_name in split_members:
split_assignments[f"{block_to_split}.{child_name}".lstrip(".")] = split_idx
# create a copy of the input model and add the split assignments as a new attribute
model.model = None
output_model = deepcopy(model)
output_model.model_attributes = model_attributes = output_model.model_attributes or {}
model_attributes["split_assignments"] = split_assignments
return output_model


@ -12,7 +12,7 @@ from typing import Any, Dict, List, Union
import torch
from olive.common.config_utils import validate_config
from olive.common.hf.mappings import MODEL_INSIDE_LAYER_MODULES, MODEL_LAYERS_BLOCK_NAME, MODEL_OUTSIDE_LAYER_MODULES
from olive.common.hf.mappings import MODEL_INSIDE_LAYER_MODULES, MODEL_OUTSIDE_LAYER_MODULES, MODELS_TO_LAYERS_MAPPING
from olive.data.config import DataConfig
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import HfModelHandler, PyTorchModelHandler
@ -161,7 +161,7 @@ class GptqQuantizer(Pass):
fields_to_set = {
"outside_layer_modules": MODEL_OUTSIDE_LAYER_MODULES,
"inside_layer_modules": MODEL_INSIDE_LAYER_MODULES,
"layers_block_name": MODEL_LAYERS_BLOCK_NAME,
"layers_block_name": MODELS_TO_LAYERS_MAPPING,
}
for key, value in fields_to_set.items():
if config[key]:


@ -116,7 +116,7 @@ def input_model_info_fixture(tmp_path_factory):
@pytest.mark.parametrize("model_type", [None, "float", "int4"])
def test_model_has_adapters(tmp_path, input_model_info, model_type):
def test_model_has_adapters(input_model_info, model_type):
if model_type is None:
assert not model_has_adapters(get_onnx_model().model_path)
else:


@ -0,0 +1,129 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import pytest
from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.conversion import OnnxConversion
from olive.passes.onnx.split import SplitModel
from olive.passes.onnx.transformer_optimization import OrtTransformersOptimization
# TODO(jambayk): Add model builder and qdq models to this test
@pytest.fixture(name="input_model_info", scope="module")
def input_model_info_fixture(tmp_path_factory):
# this tmp_path exists for the duration of the test session
# module scope is used to ensure that the fixture is only created once
tmp_path = tmp_path_factory.mktemp("test-split-model")
# store onnx models for use in tests
all_models = {}
# input model
input_model = HfModelHandler(
model_path="katuni4ka/tiny-random-phi3",
load_kwargs={"trust_remote_code": False, "revision": "585361abfee667f3c63f8b2dc4ad58405c4e34e2"},
model_attributes={"split_assignments": {"model.layers.0": 0, "model.layers.1": 1}},
)
# conversion fp32
all_models["convert_fp32"] = create_pass_from_dict(
OnnxConversion, {"torch_dtype": "float32"}, disable_search=True
).run(input_model, tmp_path / "convert_fp32")
# transformers opt fp32
all_models["opt_fp32"] = create_pass_from_dict(
OrtTransformersOptimization, {"model_type": "bert", "opt_level": 0}, disable_search=True
).run(all_models["convert_fp32"], tmp_path / "opt_fp32")
# transformers opt fp16
all_models["opt_fp16"] = create_pass_from_dict(
OrtTransformersOptimization,
{"model_type": "bert", "opt_level": 0, "float16": True, "keep_io_types": False},
disable_search=True,
).run(all_models["convert_fp32"], tmp_path / "opt_fp16")
# transformers opt fp16 with keep_io_types
all_models["opt_fp16_keep_io_types"] = create_pass_from_dict(
OrtTransformersOptimization,
{"model_type": "bert", "opt_level": 0, "float16": True, "keep_io_types": True},
disable_search=True,
).run(all_models["convert_fp32"], tmp_path / "opt_fp16_keep_io_types")
return all_models
def common_check(tmp_path, input_model_info, model_type, include_all_nodes):
input_model = input_model_info[model_type]
split_model = create_pass_from_dict(SplitModel, {"include_all_nodes": include_all_nodes}, disable_search=True).run(
input_model, tmp_path
)
assert isinstance(split_model, CompositeModelHandler)
components = list(split_model.model_components)
assert len(components) == 2
assert all(isinstance(component, ONNXModelHandler) for component in components)
# check that the split models can be loaded
# TODO(jambayk): Consider running the full model and comparing the outputs with the splits
# this is more involved. have to modify the input model to create outputs for the split outputs
for component in components:
component.prepare_session(execution_providers="CPUExecutionProvider")
# check that the splits have the expected kv inputs
for split_idx in range(2):
io_config = components[split_idx].io_config
expected_kv_inputs = [f"past_key_values.{split_idx}.key", f"past_key_values.{split_idx}.value"]
if not include_all_nodes and model_type == "opt_fp16_keep_io_types":
# key is used before the first split too
expected_kv_inputs.remove(f"past_key_values.{split_idx}.key")
assert set(expected_kv_inputs) <= set(io_config["input_names"])
expected_kv_outputs = [f"present.{split_idx}.key", f"present.{split_idx}.value"]
assert set(expected_kv_outputs) <= set(io_config["output_names"])
return input_model, components
@pytest.mark.parametrize(
"model_type",
["convert_fp32", "opt_fp32", "opt_fp16", "opt_fp16_keep_io_types"],
)
def test_split_model_all_nodes(tmp_path, input_model_info, model_type):
input_model, components = common_check(tmp_path, input_model_info, model_type, True)
input_io_config = input_model.io_config
# input of first split must be a subset of the input of the input model
assert set(components[0].io_config["input_names"]) <= set(input_io_config["input_names"])
# input of second split must be a subset of model input + first split output
assert set(components[1].io_config["input_names"]) <= set(
input_io_config["input_names"] + components[0].io_config["output_names"]
)
# output of first split not used by second split must be a subset of the output of the input model
split_1_model_output = set(components[0].io_config["output_names"]) - set(components[1].io_config["input_names"])
assert split_1_model_output <= set(input_io_config["output_names"])
# output of second split must be a subset of the output of the input model
split_2_model_output = set(components[1].io_config["output_names"])
assert split_2_model_output <= set(input_io_config["output_names"])
# all split model outputs must be the same as the input model outputs
assert (split_1_model_output | split_2_model_output) == set(input_io_config["output_names"])
@pytest.mark.parametrize(
("model_type", "expected_overlap"),
[
("convert_fp32", 3),
("opt_fp32", 4),
("opt_fp16", 4),
("opt_fp16_keep_io_types", 4),
],
)
def test_split_model_only_splits(tmp_path, input_model_info, model_type, expected_overlap):
_, components = common_check(tmp_path, input_model_info, model_type, False)
# check that the splits have the expected overlap
assert (
len(set(components[0].io_config["output_names"]).intersection(set(components[1].io_config["input_names"])))
== expected_overlap
)


@ -0,0 +1,59 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import pytest
import torch
from olive.model import HfModelHandler, PyTorchModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.pytorch.capture_split_info import CaptureSplitInfo
class CustomModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.before_layer = torch.nn.Linear(2, 4)
self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(4)])
self.after_layer = torch.nn.Linear(4, 2)
def forward(self, x):
x = self.before_layer(x)
for layer in self.layers:
x = layer(x)
return self.after_layer(x)
@pytest.mark.parametrize(
("input_model", "block_to_split", "num_splits", "split_assignments"),
[
(
PyTorchModelHandler(model_loader=lambda _: CustomModel()),
"layers",
2,
{"layers.0": 0, "layers.1": 0, "layers.2": 1, "layers.3": 1},
),
(
PyTorchModelHandler(model_loader=lambda _: CustomModel()),
"",
3,
{"before_layer": 0, "layers": 1, "after_layer": 2},
),
(
HfModelHandler(model_path="hf-internal-testing/tiny-random-LlamaForCausalLM"),
None,
2,
{"model.layers.0": 0, "model.layers.1": 1},
),
],
)
def test_capture_split_info(input_model, block_to_split, num_splits, split_assignments, tmp_path):
config = {
"num_splits": num_splits,
}
if block_to_split is not None:
config["block_to_split"] = block_to_split
p = create_pass_from_dict(CaptureSplitInfo, config, disable_search=True)
out = p.run(input_model, str(tmp_path))
assert out.model_attributes["split_assignments"] == split_assignments