SplitModel: Pass to split transformers model (#1436)
## Describe your changes

Added two new passes that split a transformers model into multiple models. The split is done on the transformer layers.

- `CaptureSplitInfo`: Uses the PyTorch/HF model to assign split ids to the transformer layers and saves them as model attributes. The conversion pass or model builder pass adds the split assignments to the ONNX model metadata.
- `SplitModel`: Uses the `split_assignments` metadata to break the model into splits.
  - If `include_all_nodes` is `True`: nodes before the first split or outside the splits are assigned to the first split. Nodes after the last split are assigned to the last split.
  - If `include_all_nodes` is `False`: such nodes are ignored. This means the embedding, attention-mask computation, final norm, and lm head don't appear in the splits; we expect the user to extract them separately.

`Cache.save_model` now supports saving composite models.

## Checklist before requesting a review

- [x] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [x] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link
Parent: 90695eba2f
Commit: 1e42f0b2b2
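For reference, the split assignments travel from the PyTorch/HF model attributes into the ONNX model as a `key=value;key=value` metadata string: the conversion/model-builder passes write it with `onnx.helper.set_model_props`, and `SplitModel` parses it back. A minimal sketch of that round trip (the helper names here are illustrative only, not part of the Olive API):

```python
import onnx


def encode_split_assignments(model_proto: onnx.ModelProto, assignments: dict) -> None:
    # e.g. {"model.layers.0": 0, "model.layers.1": 1} -> "model.layers.0=0;model.layers.1=1"
    value = ";".join(f"{k}={v}" for k, v in assignments.items())
    onnx.helper.set_model_props(model_proto, {"split_assignments": value})


def decode_split_assignments(model_proto: onnx.ModelProto) -> dict:
    # SplitModel reads the same metadata key back and raises if it is missing
    for prop in model_proto.metadata_props:
        if prop.key == "split_assignments":
            return {k: int(v) for k, v in (item.split("=") for item in prop.value.split(";"))}
    return {}
```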
@@ -144,6 +144,18 @@ ExtractAdapters
---------------
.. autoconfigclass:: olive.passes.ExtractAdapters

.. _capture_split_info:

CaptureSplitInfo
----------------
.. autoconfigclass:: olive.passes.CaptureSplitInfo

.. _split_model:

SplitModel
----------
.. autoconfigclass:: olive.passes.SplitModel

.. _lora:

LoRA
@@ -63,5 +63,5 @@
},
"host": "local_system",
"target": "local_system",
"clean_cache": false
"output_dir": "models/llama2_qlora"
}
@@ -0,0 +1,21 @@
{
    "input_model": {
        "type": "HfModel",
        "load_kwargs": { "attn_implementation": "eager" },
        "model_path": "meta-llama/Llama-2-7b-hf"
    },
    "systems": {
        "local_system": {
            "type": "LocalSystem",
            "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ]
        }
    },
    "passes": {
        "s": { "type": "CaptureSplitInfo", "num_splits": 3 },
        "c": { "type": "OnnxConversion", "target_opset": 17, "torch_dtype": "float32" },
        "sm": { "type": "SplitModel" }
    },
    "host": "local_system",
    "target": "local_system",
    "output_dir": "models/llama2_split"
}
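For context, a workflow config like the one above is typically run through Olive's Python entry point; a minimal sketch (the config filename is an assumption):

```python
from olive.workflows import run as olive_run

# Runs CaptureSplitInfo -> OnnxConversion -> SplitModel and writes the resulting
# composite model (one ONNX model per split) under models/llama2_split.
olive_run("llama2_split.json")
```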
@@ -78,3 +78,11 @@ OrtSessionParamsTuning:
  EP: null
  precision: null
  accelerator: null
CaptureSplitInfo:
  EP: null
  precision: null
  accelerator: null
SplitModel:
  EP: null
  precision: null
  accelerator: null
@@ -376,9 +376,38 @@ class OliveCache:
output_dir.mkdir(parents=True, exist_ok=True)
model_json = self.load_model(model_id)
if model_json["type"].lower() == "compositemodel":
logger.warning("Saving models of type '%s' is not supported yet.", model_json["type"])
return None
model_json_config = model_json["config"]
copied_components = []
for component_names, component in zip(
model_json_config["model_component_names"], model_json_config["model_components"]
):
copied_components.append(
self._save_model(
component,
output_dir=output_dir,
overwrite=overwrite,
only_cache_files=only_cache_files,
path_prefix=component_names,
)
)
model_json_config["model_components"] = copied_components
model_json = self._save_additional_files(model_json, output_dir)
else:
model_json = self._save_model(model_json, output_dir, overwrite)

# save model json
with (output_dir / "model_config.json").open("w") as f:
json.dump(model_json, f, indent=4)
return model_json

def _save_model(
self,
model_json: dict,
output_dir: str,
overwrite: bool = False,
only_cache_files: bool = False,
path_prefix: str = None,
) -> dict:
# create model object so that we can get the resource paths
model_config: ModelConfig = ModelConfig.from_json(model_json)
resource_paths = model_config.get_resource_paths()

@@ -413,17 +442,26 @@ class OliveCache:
continue

# save resource to output directory
model_json["config"][resource_name] = local_resource_path.save_to_dir(
output_dir, resource_name.replace("_path", ""), overwrite
)
path_name = resource_name.replace("_path", "")
if path_prefix:
path_name = f"{path_prefix}_{path_name}"
model_json["config"][resource_name] = local_resource_path.save_to_dir(output_dir, path_name, overwrite)

# we only have additional files for onnx models so saving to "model" is safe
model_path_name = "model"
if path_prefix:
model_path_name = f"{path_prefix}_{model_path_name}"
return self._save_additional_files(model_json, output_dir / model_path_name)

def _save_additional_files(self, model_json: dict, output_dir: Path) -> dict:
# Copy "additional files" to the model folder
# we only have additional files for onnx models so saving to "model" is safe
model_attributes = model_json["config"].get("model_attributes") or {}
additional_files = model_attributes.get("additional_files", [])

for i, src_filepath in enumerate(additional_files):
dst_filepath = output_dir / "model" / Path(src_filepath).name
output_dir.mkdir(parents=True, exist_ok=True)
dst_filepath = output_dir / Path(src_filepath).name
additional_files[i] = str(dst_filepath)

if not dst_filepath.exists():

@@ -432,9 +470,6 @@ class OliveCache:
if additional_files:
model_json["config"]["model_attributes"]["additional_files"] = additional_files

# save model json
with (output_dir / "model_config.json").open("w") as f:
json.dump(model_json, f, indent=4)
return model_json

def disable_shared_cache(self):
@ -14,24 +14,42 @@ TASK_TO_PEFT_TASK_TYPE = {
|
|||
# model_type -> name for layers
|
||||
MODELS_TO_LAYERS_MAPPING = {
|
||||
"bloom": "transformer.h",
|
||||
"falcon": "transformer.h",
|
||||
"gemma": "model.layers",
|
||||
"gemma2": "model.layers",
|
||||
"gpt2": "transformer.h",
|
||||
"gpt_neox": "gpt_neox.layers",
|
||||
"gptj": "transformer.h",
|
||||
"llama": "model.layers",
|
||||
"mistral": "model.layers",
|
||||
"opt": "model.decoder.layers",
|
||||
"phi": "model.layers",
|
||||
"phi3": "model.layers",
|
||||
"qwen": "transformer.h",
|
||||
"qwen2": "model.layers",
|
||||
}
|
||||
|
||||
# model_type -> name for embedding, these are the modules before the first layer
|
||||
MODELS_TO_EMBEDDINGS_MAPPING = {
|
||||
"bloom": ["transformer.word_embeddings", "transformer.word_embeddings_layernorm"],
|
||||
"falcon": ["transformer.word_embeddings"],
|
||||
"gemma": ["model.embed_tokens"],
|
||||
"gemma2": ["model.embed_tokens"],
|
||||
"gpt2": ["transformer.wte", "transformer.wpe"],
|
||||
"gpt_neox": ["gpt_neox.embed_in"],
|
||||
"gptj": ["transformer.wte"],
|
||||
"llama": ["model.embed_tokens"],
|
||||
"mistral": ["model.embed_tokens"],
|
||||
"opt": [
|
||||
"model.decoder.embed_tokens",
|
||||
"model.decoder.embed_positions",
|
||||
"model.model.decoder.project_out",
|
||||
"model.model.decoder.project_in",
|
||||
"model.decoder.project_out",
|
||||
"model.decoder.project_in",
|
||||
],
|
||||
"phi": ["model.embed_tokens"],
|
||||
"phi3": ["model.embed_tokens"],
|
||||
"qwen": ["transformer.wte", "transformer.rotary_emb"],
|
||||
"qwen2": ["model.embed_tokens"],
|
||||
}
|
||||
|
||||
# model_type -> max length of the model, extracted from the config
|
||||
|
@ -41,8 +59,14 @@ MODELS_TO_MAX_LENGTH_MAPPING = {
|
|||
"bloom": 2048,
|
||||
"gpt2": "n_positions",
|
||||
"gpt_neox": "max_position_embeddings",
|
||||
"gptj": "n_postions",
|
||||
"llama": "max_position_embeddings",
|
||||
"mistral": "max_position_embeddings",
|
||||
"opt": "max_position_embeddings",
|
||||
"phi": "max_position_embeddings",
|
||||
"phi3": "max_position_embeddings",
|
||||
"qwen": "seq_length",
|
||||
"qwen2": "max_position_embeddings",
|
||||
}
|
||||
|
||||
|
||||
|
@ -85,6 +109,4 @@ MODEL_INSIDE_LAYER_MODULES = {
|
|||
]
|
||||
}
|
||||
|
||||
MODEL_LAYERS_BLOCK_NAME = {"phi3": "model.layers"}
|
||||
|
||||
MODELS_TO_LORA_TARGET_MODULES_MAPPING = {"phi3": ["o_proj", "qkv_proj"]}
|
||||
|
|
|
@@ -59,6 +59,7 @@
"VitisQOpQuantizer": { "module_path": "olive.passes.onnx.vitis_ai.quantizer.VitisQOpQuantizer" },
"quantize_static": { "module_path": "olive.passes.onnx.vitis_ai.quantize.quantize_static" },
"PowerOfTwoMethod": { "module_path": "olive.passes.onnx.vitis_ai.quant_utils.PowerOfTwoMethod" },
"SplitModel": { "module_path": "olive.passes.onnx.split.SplitModel" },
"OpenVINOConversion": {
"module_path": "olive.passes.openvino.conversion.OpenVINOConversion",
"extra_dependencies": [ "openvino" ]

@@ -67,14 +68,15 @@
"module_path": "olive.passes.openvino.quantization.OpenVINOQuantization",
"extra_dependencies": [ "openvino" ]
},
"GptqQuantizer": {
"module_path": "olive.passes.pytorch.gptq.GptqQuantizer",
"module_dependencies": [ "auto-gptq", "optimum" ]
},
"AutoAWQQuantizer": {
"module_path": "olive.passes.pytorch.autoawq.AutoAWQQuantizer",
"module_dependencies": [ "autoawq" ]
},
"GptqQuantizer": {
"module_path": "olive.passes.pytorch.gptq.GptqQuantizer",
"module_dependencies": [ "auto-gptq", "optimum" ]
},
"CaptureSplitInfo": { "module_path": "olive.passes.pytorch.capture_split_info.CaptureSplitInfo" },
"MergeAdapterWeights": { "module_path": "olive.passes.pytorch.merge_adapter_weights.MergeAdapterWeights" },
"LoftQ": { "module_path": "olive.passes.pytorch.lora.LoftQ" },
"LoRA": { "module_path": "olive.passes.pytorch.lora.LoRA", "extra_dependencies": [ "lora" ] },
@@ -246,7 +246,11 @@ class Pass(ABC):
# assumption: the model attributes from passes, if any, are more important than
# the input model attributes, we should not update/extend anymore outside of the pass run
output_model.model_attributes = output_model.model_attributes or model.model_attributes
Pass._carry_forward_additional_files(model, output_model)
if not isinstance(output_model, CompositeModelHandler):
# save and carry forward additional files into the the output model path
# for composite model, the additional_files attribute is already present in the parent
# model_attributes
Pass._carry_forward_additional_files(model, output_model)

return output_model
@@ -10,7 +10,6 @@ from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import onnx
import onnx.external_data_helper
import torch
from packaging import version

@@ -376,6 +375,14 @@ class OnnxConversion(Pass):
pytorch_model, dummy_inputs, io_config, config, device, torch_dtype, tempfile.tempdir
)

model_attributes = deepcopy(model.model_attributes or {})

# add split information if present
split_assignments = model_attributes.get("split_assignments")
if split_assignments:
split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()])
onnx.helper.set_model_props(converted_onnx_model, {"split_assignments": split_assignment_str})

# save the model to the output path and return the model
output_model_path = resolve_onnx_path(output_model_path)
output_model = model_proto_to_olive_model(converted_onnx_model, output_model_path, config)
@@ -246,6 +246,8 @@ class MatMulNBitsToQDQ(Pass):
dag.add_value_info(vi, graph_idx)

# remove the node
if is_model_output:
dag.remove_output(node_output)
dag.remove_node(node_name)

# rename to original name if it is a model output
@ -10,6 +10,7 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import onnx
|
||||
import transformers
|
||||
|
||||
from olive.common.utils import IntEnumBase, StrEnumBase
|
||||
|
@ -177,6 +178,8 @@ class ModelBuilder(Pass):
|
|||
if config[arg] is not None:
|
||||
extra_args[arg] = "1" if config[arg] else "0"
|
||||
|
||||
model_attributes = copy.deepcopy(model.model_attributes or {})
|
||||
|
||||
try:
|
||||
create_model(
|
||||
model_name=model_path,
|
||||
|
@ -189,6 +192,17 @@ class ModelBuilder(Pass):
|
|||
cache_dir=transformers.utils.TRANSFORMERS_CACHE,
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
# add split information if present
|
||||
split_assignments = model_attributes.get("split_assignments")
|
||||
if split_assignments:
|
||||
split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()])
|
||||
|
||||
# load the model and set the split_assignments as model properties
|
||||
# without the external data so that they can be used as is with the resaved model
|
||||
model_proto = onnx.load(output_model_filepath, load_external_data=False)
|
||||
onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str})
|
||||
onnx.save(model_proto, output_model_filepath)
|
||||
except Exception:
|
||||
# if model building fails, clean up the intermediate files in the cache_dir
|
||||
cache_dir = Path(transformers.utils.TRANSFORMERS_CACHE)
|
||||
|
@ -209,7 +223,6 @@ class ModelBuilder(Pass):
|
|||
json.dump(genai_config, ostrm, indent=4)
|
||||
|
||||
# add additional files generated by model builder to model_attributes
|
||||
model_attributes = copy.deepcopy(model.model_attributes or {})
|
||||
model_attributes["is_generative"] = True
|
||||
additional_files = model_attributes.get("additional_files") or []
|
||||
if metadata_only:
|
||||
|
|
|
@ -81,9 +81,9 @@ class OnnxIO(ConfigBase):
|
|||
class OnnxDAG:
|
||||
"""ONNX model as a directed acyclic graph (DAG)."""
|
||||
|
||||
def __init__(self, model: "ModelProto"):
|
||||
def __init__(self, model: "ModelProto", only_main_graph: bool = False):
|
||||
self.model = model
|
||||
self.graphs = self.get_all_graphs(self.model)
|
||||
self.graphs = self.get_all_graphs(self.model, only_main_graph)
|
||||
self.nodes: Dict[str, OnnxNode] = {}
|
||||
self.ios: Dict[str, OnnxIO] = {}
|
||||
self.connections = defaultdict(list)
|
||||
|
@ -95,12 +95,16 @@ class OnnxDAG:
|
|||
self._process_node(node, self.nodes, self.ios, self.connections, idx)
|
||||
|
||||
@staticmethod
|
||||
def get_all_graphs(model: "ModelProto") -> List[GraphProto]:
|
||||
def get_all_graphs(model: "ModelProto", only_main_graph: bool = False) -> List[GraphProto]:
|
||||
"""Get all graphs in the model.
|
||||
|
||||
:param model: ONNX model.
|
||||
:param only_main_graph: whether to return only the main graph.
|
||||
:return: list of graphs in the model.
|
||||
"""
|
||||
if only_main_graph:
|
||||
return [model.graph]
|
||||
|
||||
all_graphs = []
|
||||
graph_queue = [model.graph]
|
||||
while graph_queue:
|
||||
|
@ -203,8 +207,7 @@ class OnnxDAG:
|
|||
raise ValueError(f"Output {o} is already connected to another node.")
|
||||
ios[o].source = name
|
||||
for destination in ios[o].destination:
|
||||
if destination != SpecialOutput.OUTPUT:
|
||||
connections[name].append(destination)
|
||||
connections[name].append(destination)
|
||||
|
||||
def add_input(self, input_proto: ValueInfoProto, graph_idx: int, keep_initializer: bool = False):
|
||||
"""Add an input to the graph.
|
||||
|
@ -224,7 +227,7 @@ class OnnxDAG:
|
|||
"""
|
||||
self._add_special_input(initializer, graph_idx, SpecialInput.INITIALIZER, keep_input)
|
||||
|
||||
def add_value_info(self, value_info: ValueInfoProto, graph_idx: int):
|
||||
def add_value_info(self, value_info: ValueInfoProto, graph_idx: int, overwrite: bool = False):
|
||||
"""Add a value info to the graph.
|
||||
|
||||
:param value_info: ValueInfoProto of the value info.
|
||||
|
@ -236,8 +239,18 @@ class OnnxDAG:
|
|||
self.ios[name] = OnnxIO(proto=[value_info], graph_idx=graph_idx)
|
||||
return
|
||||
|
||||
assert not self.ios[name].proto, f"Value info for {name} already exists in the graph."
|
||||
self.ios[name].proto.append(value_info)
|
||||
assert (
|
||||
overwrite or not self.ios[name].proto
|
||||
), f"Value info for {name} already exists in the graph but overwrite is False."
|
||||
self.ios[name].proto = [value_info]
|
||||
|
||||
def is_io(self, io_name: str) -> bool:
|
||||
"""Check if an input/output exists in the graph.
|
||||
|
||||
:param io_name: name of the input/output.
|
||||
:return: True if the input/output exists.
|
||||
"""
|
||||
return io_name in self.ios
|
||||
|
||||
def _add_special_input(
|
||||
self,
|
||||
|
@ -463,21 +476,37 @@ class OnnxDAG:
|
|||
"""
|
||||
return self.nodes[node_name].op_type
|
||||
|
||||
def get_node_inputs(self, node_name: str) -> List[str]:
|
||||
def get_node_proto(self, node_name: str) -> NodeProto:
|
||||
"""Get the node proto.
|
||||
|
||||
:param node_name: name of the node.
|
||||
:return: NodeProto object.
|
||||
"""
|
||||
return self.nodes[node_name].proto
|
||||
|
||||
def get_node_inputs(self, node_name: str, skip_empty_io: bool = False) -> List[str]:
|
||||
"""Get the input names of a node.
|
||||
|
||||
:param node_name: name of the node.
|
||||
:param skip_empty_io: whether to skip empty inputs.
|
||||
:return: list of input names.
|
||||
"""
|
||||
return list(self.nodes[node_name].inputs)
|
||||
inputs = self.nodes[node_name].inputs
|
||||
if skip_empty_io:
|
||||
inputs = filter(lambda i: i != "", inputs)
|
||||
return list(inputs)
|
||||
|
||||
def get_node_outputs(self, node_name: str) -> List[str]:
|
||||
def get_node_outputs(self, node_name: str, skip_empty_io: bool = False) -> List[str]:
|
||||
"""Get the output names of a node.
|
||||
|
||||
:param node_name: name of the node.
|
||||
:param skip_empty_io: whether to skip empty outputs.
|
||||
:return: list of output names.
|
||||
"""
|
||||
return list(self.nodes[node_name].outputs)
|
||||
outputs = self.nodes[node_name].outputs
|
||||
if skip_empty_io:
|
||||
outputs = filter(lambda o: o != "", outputs)
|
||||
return list(outputs)
|
||||
|
||||
def get_node_attributes(self, node_name: str) -> Dict[str, Any]:
|
||||
"""Get the attributes of a node.
|
||||
|
@ -511,6 +540,19 @@ class OnnxDAG:
|
|||
"""
|
||||
return SpecialInput.is_initializer(self.ios[io_name].source)
|
||||
|
||||
def is_constant_input(self, io_name: str, allow_input_initializer: bool = False) -> bool:
|
||||
"""Check if an input/output comes from an initializer or a constant node.
|
||||
|
||||
:param io_name: name of the input/output.
|
||||
:param allow_input_initializer: whether to consider input_initializer as a constant input.
|
||||
:return: True if the input/output is a constant input.
|
||||
"""
|
||||
source = self.ios[io_name].source
|
||||
if source in self.nodes:
|
||||
return self.get_node_op_type(source) == "Constant"
|
||||
|
||||
return (source == SpecialInput.INITIALIZER) or (allow_input_initializer and SpecialInput.is_initializer(source))
|
||||
|
||||
def get_graph_idx(self, name: str) -> int:
|
||||
"""Get the index of the graph containing the input/output or node."""
|
||||
if name in self.ios:
|
||||
|
@ -544,6 +586,17 @@ class OnnxDAG:
|
|||
if self.is_output(io_name):
|
||||
return
|
||||
self.ios[io_name].destination.append(SpecialOutput.OUTPUT)
|
||||
self.connections[self.get_producer(io_name)].append(SpecialOutput.OUTPUT)
|
||||
|
||||
def remove_output(self, io_name: str):
|
||||
"""Remove an output from an input/output.
|
||||
|
||||
:param io_name: name of the input/output.
|
||||
"""
|
||||
if not self.is_output(io_name):
|
||||
return
|
||||
self.ios[io_name].destination.remove(SpecialOutput.OUTPUT)
|
||||
self.connections[self.get_producer(io_name)].remove(SpecialOutput.OUTPUT)
|
||||
|
||||
def get_producer(self, io_name: str) -> str:
|
||||
"""Get the producer of an input/output.
|
||||
|
@ -553,16 +606,43 @@ class OnnxDAG:
|
|||
"""
|
||||
return self.ios[io_name].source
|
||||
|
||||
def get_consumers(self, node_name: str) -> List[str]:
|
||||
def get_consumers(self, node_name: str, return_special_outputs: bool = False) -> List[str]:
|
||||
"""Get the consumers of a node.
|
||||
|
||||
:param node_name: name of the node. It can also be an input or initializer.
|
||||
|
||||
:return: list of names of nodes that consume one/more outputs of the node.
|
||||
"""
|
||||
if node_name in self.ios and SpecialInput.is_special_input(self.ios[node_name].source):
|
||||
return list(self.ios[node_name].destination)
|
||||
if node_name in self.ios:
|
||||
consumers = self.ios[node_name].destination
|
||||
else:
|
||||
consumers = self.connections[node_name]
|
||||
|
||||
return list(self.connections[node_name])
|
||||
if return_special_outputs:
|
||||
return list(consumers)
|
||||
|
||||
return list(filter(lambda c: c != SpecialOutput.OUTPUT, consumers))
|
||||
|
||||
def get_parents(self, node_name: str, return_special_inputs: bool = False) -> List[str]:
|
||||
"""Get the parents of a node.
|
||||
|
||||
:param node_name: name of the node.
|
||||
:return: list of names of nodes that produce one/more inputs of the node.
|
||||
"""
|
||||
parents = [self.ios[i].source for i in self.nodes[node_name].inputs]
|
||||
|
||||
if return_special_inputs:
|
||||
return parents
|
||||
|
||||
return list(filter(lambda p: not SpecialInput.is_special_input(p), parents))
|
||||
|
||||
def is_input_consumer(self, node_name: str) -> bool:
|
||||
"""Check if a node is an input consumer.
|
||||
|
||||
:param node_name: name of the node.
|
||||
:return: True if the node consumes one/more inputs that are also model inputs.
|
||||
"""
|
||||
return any(SpecialInput.is_special_input(p) for p in self.get_parents(node_name, return_special_inputs=True))
|
||||
|
||||
def is_output_producer(self, node_name: str) -> bool:
|
||||
"""Check if a node is an output producer.
|
||||
|
@ -570,7 +650,7 @@ class OnnxDAG:
|
|||
:param node_name: name of the node.
|
||||
:return: True if the node produces one/more outputs that are also model outputs.
|
||||
"""
|
||||
return any(SpecialOutput.OUTPUT in self.ios[o].destination for o in self.nodes[node_name].outputs)
|
||||
return SpecialOutput.OUTPUT in self.get_consumers(node_name, return_special_outputs=True)
|
||||
|
||||
def _topological_sort_util(self, v: str, visited: Set[str], order: List[str]):
|
||||
"""Do depth-first search starting from node v.
|
||||
|
@ -683,10 +763,11 @@ class OnnxDAG:
|
|||
logger.debug("Removed %d Identity nodes", len(nodes_to_remove))
|
||||
|
||||
@classmethod
|
||||
def from_model_path(cls, model_path: Union[str, Path]) -> "OnnxDAG":
|
||||
def from_model_path(cls, model_path: Union[str, Path], only_main_graph: bool = False) -> "OnnxDAG":
|
||||
"""Load an ONNX model and create an self.
|
||||
|
||||
:param model_path: path to the ONNX model.
|
||||
:param only_main_graph: whether to create a DAG with only the main graph.
|
||||
:return: OnnxDAG object.
|
||||
"""
|
||||
return cls(onnx.load(model_path))
|
||||
return cls(onnx.load(model_path), only_main_graph=only_main_graph)
|
||||
|
|
|
@ -0,0 +1,280 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
|
||||
from olive.hardware.accelerator import AcceleratorSpec
|
||||
from olive.model import CompositeModelHandler, ONNXModelHandler
|
||||
from olive.model.utils import resolve_onnx_path
|
||||
from olive.passes import Pass
|
||||
from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
|
||||
from olive.passes.onnx.onnx_dag import OnnxDAG
|
||||
from olive.passes.pass_config import PassConfigParam
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SplitModel(Pass):
|
||||
@classmethod
|
||||
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
|
||||
return {
|
||||
"include_all_nodes": PassConfigParam(
|
||||
type_=bool,
|
||||
default_value=True,
|
||||
description=(
|
||||
"Include all nodes in the split model. Nodes outside the splits or before the first split will be"
|
||||
" assigned to the first split. Nodes after the last split will be assigned to the last split. If"
|
||||
" False, these nodes will not be included in the split models."
|
||||
),
|
||||
),
|
||||
**get_external_data_config(),
|
||||
}
|
||||
|
||||
def _run_for_config(
|
||||
self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str
|
||||
) -> CompositeModelHandler:
|
||||
model_proto = model.load_model()
|
||||
|
||||
split_assignments = None
|
||||
for metadata_prop in model_proto.metadata_props:
|
||||
if metadata_prop.key == "split_assignments":
|
||||
split_assignments = {
|
||||
key: int(value)
|
||||
for key, value in (assignment.split("=") for assignment in metadata_prop.value.split(";"))
|
||||
}
|
||||
break
|
||||
# TODO(jambayk): Should we allow split assignments in the model attributes too?
|
||||
if not split_assignments:
|
||||
raise ValueError("No split assignments found in the model metadata")
|
||||
|
||||
# TODO(jambayk): Make this more generic, for now only assume transformers layers are split
|
||||
# so depth of namespace is same for all split assignments
|
||||
num_splits = len(np.unique(list(split_assignments.values())))
|
||||
namespace_depth = len(next(iter(split_assignments)).split("."))
|
||||
|
||||
# create a dag for the model, won't split nested graphs
|
||||
dag = OnnxDAG(model_proto, only_main_graph=True)
|
||||
dag.remove_identity_nodes()
|
||||
# empy dags for each split
|
||||
split_proto = onnx.ModelProto(
|
||||
ir_version=model_proto.ir_version,
|
||||
opset_import=model_proto.opset_import,
|
||||
producer_name="olive",
|
||||
graph=onnx.GraphProto(name=model_proto.graph.name),
|
||||
)
|
||||
split_dags = [OnnxDAG(deepcopy(split_proto)) for _ in range(num_splits)]
|
||||
|
||||
# go through the nodes in topological order
|
||||
node_order = dag.topological_sort()
|
||||
node_assignments = {}
|
||||
constant_nodes = set()
|
||||
for node_name in node_order:
|
||||
# will handle constant nodes laters
|
||||
if dag.get_node_op_type(node_name) == "Constant":
|
||||
constant_nodes.add(node_name)
|
||||
continue
|
||||
|
||||
name_components = node_name.replace("/", ".").lstrip(".").split(".")
|
||||
namespace = ".".join(name_components[:namespace_depth])
|
||||
if namespace in split_assignments:
|
||||
node_assignments[node_name] = split_assignments[namespace]
|
||||
|
||||
# what is the next closest split, if not assigned to a split
|
||||
next_split = deepcopy(node_assignments)
|
||||
# already have a topological order, so we will go from the bottom up
|
||||
for node_name in node_order[::-1]:
|
||||
# constants cannot be children of node
|
||||
if node_name in node_assignments or node_name in constant_nodes:
|
||||
continue
|
||||
|
||||
child_splits = [
|
||||
next_split[child_name]
|
||||
for child_name in dag.get_consumers(node_name)
|
||||
if next_split[child_name] is not None
|
||||
]
|
||||
if child_splits:
|
||||
next_split[node_name] = min(child_splits)
|
||||
else:
|
||||
next_split[node_name] = None
|
||||
|
||||
if config["include_all_nodes"]:
|
||||
for node_name in node_order:
|
||||
if node_name in node_assignments or node_name in constant_nodes:
|
||||
continue
|
||||
|
||||
# parent has a split - after last split or the before/outside parent has been assigned
|
||||
parent_splits = [
|
||||
node_assignments[parent_name]
|
||||
for parent_name in dag.get_parents(node_name)
|
||||
if parent_name in node_assignments
|
||||
]
|
||||
if parent_splits:
|
||||
node_assignments[node_name] = max(parent_splits)
|
||||
continue
|
||||
|
||||
# before the first split
|
||||
if next_split[node_name] is not None:
|
||||
node_assignments[node_name] = next_split[node_name]
|
||||
continue
|
||||
|
||||
# outside the splits
|
||||
node_assignments[node_name] = 0
|
||||
else:
|
||||
# handle unassigned nodes that are:
|
||||
# - between splits: assign to the split of the parent node
|
||||
# - between constant/initializer and splits: assign the next split
|
||||
for node_name in node_order:
|
||||
if node_name in node_assignments or node_name in constant_nodes:
|
||||
continue
|
||||
|
||||
# after the last split
|
||||
if next_split[node_name] is None:
|
||||
continue
|
||||
|
||||
# between splits
|
||||
parent_splits = [
|
||||
node_assignments[parent_name]
|
||||
for parent_name in dag.get_parents(node_name)
|
||||
if parent_name in node_assignments
|
||||
]
|
||||
if parent_splits:
|
||||
node_assignments[node_name] = max(parent_splits)
|
||||
continue
|
||||
|
||||
# between constant/initializer and splits
|
||||
if all(dag.is_constant_input(input_name) for input_name in dag.get_node_inputs(node_name)):
|
||||
node_assignments[node_name] = next_split[node_name]
|
||||
|
||||
# handle cast nodes of the from:
|
||||
# - Input -> Cast -> Split
|
||||
# - Split -> Cast -> Output
|
||||
for node_name in node_order:
|
||||
if (
|
||||
node_name in node_assignments
|
||||
or dag.get_node_op_type(node_name) != "Cast"
|
||||
# only one consumer (model output or another node)
|
||||
or len(dag.get_consumers(node_name, True)) != 1
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
dag.is_input_consumer(node_name)
|
||||
and (consumer := dag.get_consumers(node_name, True)[0]) in node_assignments
|
||||
):
|
||||
node_assignments[node_name] = node_assignments[consumer]
|
||||
elif (
|
||||
parent_name := dag.get_parents(node_name, True)[0]
|
||||
) in node_assignments and dag.is_output_producer(node_name):
|
||||
node_assignments[node_name] = node_assignments[parent_name]
|
||||
|
||||
# handle constant nodes, will add a copy of the constant to each split
|
||||
for node_name in constant_nodes:
|
||||
splits = set()
|
||||
for consumer in dag.get_consumers(node_name):
|
||||
if consumer in node_assignments:
|
||||
splits.add(node_assignments[consumer])
|
||||
if splits:
|
||||
node_assignments[node_name] = list(splits)
|
||||
|
||||
# add the nodes to the split dags
|
||||
# keep track of missing value info for inputs to the split dags
|
||||
missing_vi = defaultdict(list)
|
||||
for node_name in node_order:
|
||||
split_id = node_assignments.get(node_name)
|
||||
if split_id is None:
|
||||
continue
|
||||
|
||||
if not isinstance(split_id, list):
|
||||
# no need to worry about inputs for list split_id since it's only for constants
|
||||
split_dag = split_dags[split_id]
|
||||
# add the inputs to the nodes if not already present
|
||||
for input_name in dag.get_node_inputs(node_name):
|
||||
if not input_name:
|
||||
# optional input left as ""
|
||||
continue
|
||||
# already added
|
||||
if split_dag.is_io(input_name):
|
||||
continue
|
||||
|
||||
io = dag.get_io(input_name)
|
||||
|
||||
# main graph inputs and/or initializers
|
||||
if dag.is_input(input_name) or dag.is_initializer(input_name):
|
||||
if dag.is_input(input_name):
|
||||
split_dags[split_id].add_input(io.proto[0], 0, True)
|
||||
if dag.is_initializer(input_name):
|
||||
split_dags[split_id].add_initializer(io.proto[-1], 0, True)
|
||||
continue
|
||||
|
||||
# cross split inputs
|
||||
proto = io.proto[0] if io.proto else None
|
||||
if not proto:
|
||||
# missing value info
|
||||
missing_vi[input_name].append(split_id)
|
||||
proto = onnx.helper.make_empty_tensor_value_info(input_name)
|
||||
split_dag.add_input(proto, 0)
|
||||
|
||||
# add the node to the split dag
|
||||
split_dag.add_node(dag.get_node_proto(node_name), 0)
|
||||
|
||||
# process the node outputs
|
||||
for output_name in dag.get_node_outputs(node_name):
|
||||
if not output_name:
|
||||
# optional output left as ""
|
||||
continue
|
||||
|
||||
# mark as output if any consumer is not in the split
|
||||
is_output = False
|
||||
for consumer in dag.get_consumers(output_name, True):
|
||||
if node_assignments.get(consumer) != split_id:
|
||||
split_dag.make_output(output_name)
|
||||
is_output = True
|
||||
break
|
||||
|
||||
# add vi for the outputs
|
||||
io = dag.get_io(output_name)
|
||||
if io.proto:
|
||||
split_dag.add_value_info(io.proto[0], 0)
|
||||
elif is_output:
|
||||
# missing value info
|
||||
missing_vi[output_name].append(split_id)
|
||||
else:
|
||||
# add the constant to each split
|
||||
for idx in split_id:
|
||||
split_dags[idx].add_node(dag.get_node_proto(node_name), 0)
|
||||
|
||||
if missing_vi:
|
||||
logger.debug("Missing value info for some io. Using onnxruntime shape inference to infer them.")
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
||||
# should we just use the same model proto? might modify dynamic shapes of existing value infos
|
||||
# if this becomes an issue replace with a newly loaded model proto
|
||||
shape_inferred_proto = SymbolicShapeInference.infer_shapes(model_proto, auto_merge=True)
|
||||
shape_inferred_dag = OnnxDAG(shape_inferred_proto, only_main_graph=True)
|
||||
|
||||
for input_name, split_ids in missing_vi.items():
|
||||
io = shape_inferred_dag.get_io(input_name)
|
||||
if not io.proto:
|
||||
raise ValueError(f"Missing value info for input {input_name} for split {split_id}")
|
||||
for idx in split_ids:
|
||||
split_dags[idx].add_value_info(io.proto[0], 0, overwrite=True)
|
||||
|
||||
component_models = []
|
||||
component_names = []
|
||||
for i, split_dag in enumerate(split_dags):
|
||||
split_name = f"split_{i}"
|
||||
split_dir = Path(output_model_path).with_suffix("") / split_name
|
||||
split_path = resolve_onnx_path(split_dir, f"{split_name}.onnx")
|
||||
split_dag.update()
|
||||
component_models.append(model_proto_to_olive_model(split_dag.model, split_path, config))
|
||||
component_names.append(split_name)
|
||||
|
||||
return CompositeModelHandler(component_models, component_names)
|
|
@@ -0,0 +1,76 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from typing import Any, Dict, Union

import numpy as np

from olive.common.hf.mappings import MODELS_TO_LAYERS_MAPPING
from olive.common.utils import get_attr
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import HfModelHandler, PyTorchModelHandler
from olive.passes import Pass
from olive.passes.pass_config import PassConfigParam

logger = logging.getLogger(__name__)


class CaptureSplitInfo(Pass):
    """Capture the split information of the model layers. Only splits the transformer layers."""

    @classmethod
    def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
        return {
            "num_splits": PassConfigParam(
                type_=int,
                required=True,
                description="Number of splits to divide the model layers into.",
            ),
            "block_to_split": PassConfigParam(
                type_=str,
                default_value=None,
                description=(
                    "Name of the model block to split. Children of the block will be divided into the splits. For"
                    " supported transformers models, the default value is the transformers layer block name. Refer to"
                    " olive.common.hf.mappings.MODELS_TO_LAYERS_MAPPING for supported models."
                ),
            ),
        }

    def _run_for_config(
        self, model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any], output_model_path: str
    ) -> Union[HfModelHandler, PyTorchModelHandler]:
        block_to_split = config["block_to_split"]
        # check for None specifically since "" is a valid value
        if block_to_split is None and isinstance(model, HfModelHandler):
            model_type = model.get_hf_model_type()
            block_to_split = MODELS_TO_LAYERS_MAPPING.get(model_type, None)
        if block_to_split is None:
            raise ValueError("block_to_split is not set and could not be inferred. Please set it manually.")

        block_members = []
        # we could get the number of layers for hf model from the model attributes
        # but will just load the model to make the logic simple for now
        # consider loading with meta device to avoid loading the weights
        loaded_model = model.load_model(cache_model=False)
        block = get_attr(loaded_model, block_to_split)
        if block is None:
            raise ValueError(f"block_to_split {block_to_split} not found in model.")
        for child_name, _ in block.named_children():
            block_members.append(child_name)

        split_assignments = {}
        for split_idx, split_members in enumerate(np.array_split(block_members, config["num_splits"])):
            for child_name in split_members:
                split_assignments[f"{block_to_split}.{child_name}".lstrip(".")] = split_idx

        # create a copy of the input model and add the split assignments as a new attribute
        model.model = None
        output_model = deepcopy(model)
        output_model.model_attributes = model_attributes = output_model.model_attributes or {}
        model_attributes["split_assignments"] = split_assignments

        return output_model
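As a rough standalone usage sketch of the new pass (mirroring the unit test added in this PR), the captured assignments land in the output model's `model_attributes`:

```python
from olive.model import HfModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.pytorch.capture_split_info import CaptureSplitInfo

# tiny test model used by the PR's unit tests; any supported HF model works
input_model = HfModelHandler(model_path="hf-internal-testing/tiny-random-LlamaForCausalLM")
p = create_pass_from_dict(CaptureSplitInfo, {"num_splits": 2}, disable_search=True)

out = p.run(input_model, "tmp/capture_split_info")
# e.g. {"model.layers.0": 0, "model.layers.1": 1} for this two-layer model
print(out.model_attributes["split_assignments"])
```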
@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Union
import torch

from olive.common.config_utils import validate_config
from olive.common.hf.mappings import MODEL_INSIDE_LAYER_MODULES, MODEL_LAYERS_BLOCK_NAME, MODEL_OUTSIDE_LAYER_MODULES
from olive.common.hf.mappings import MODEL_INSIDE_LAYER_MODULES, MODEL_OUTSIDE_LAYER_MODULES, MODELS_TO_LAYERS_MAPPING
from olive.data.config import DataConfig
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import HfModelHandler, PyTorchModelHandler

@@ -161,7 +161,7 @@ class GptqQuantizer(Pass):
fields_to_set = {
"outside_layer_modules": MODEL_OUTSIDE_LAYER_MODULES,
"inside_layer_modules": MODEL_INSIDE_LAYER_MODULES,
"layers_block_name": MODEL_LAYERS_BLOCK_NAME,
"layers_block_name": MODELS_TO_LAYERS_MAPPING,
}
for key, value in fields_to_set.items():
if config[key]:
@@ -116,7 +116,7 @@ def input_model_info_fixture(tmp_path_factory):


@pytest.mark.parametrize("model_type", [None, "float", "int4"])
def test_model_has_adapters(tmp_path, input_model_info, model_type):
def test_model_has_adapters(input_model_info, model_type):
if model_type is None:
assert not model_has_adapters(get_onnx_model().model_path)
else:
@ -0,0 +1,129 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import pytest
|
||||
|
||||
from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.onnx.conversion import OnnxConversion
|
||||
from olive.passes.onnx.split import SplitModel
|
||||
from olive.passes.onnx.transformer_optimization import OrtTransformersOptimization
|
||||
|
||||
|
||||
# TODO(jambayk): Add model builder and qdq models to this test
|
||||
@pytest.fixture(name="input_model_info", scope="module")
|
||||
def input_model_info_fixture(tmp_path_factory):
|
||||
# this tmp_path exists for the duration of the test session
|
||||
# module is scope is used to ensure that the fixture is only created once
|
||||
tmp_path = tmp_path_factory.mktemp("test-split-model")
|
||||
|
||||
# store onnx models for use in tests
|
||||
all_models = {}
|
||||
|
||||
# input model
|
||||
input_model = HfModelHandler(
|
||||
model_path="katuni4ka/tiny-random-phi3",
|
||||
load_kwargs={"trust_remote_code": False, "revision": "585361abfee667f3c63f8b2dc4ad58405c4e34e2"},
|
||||
model_attributes={"split_assignments": {"model.layers.0": 0, "model.layers.1": 1}},
|
||||
)
|
||||
|
||||
# conversion fp32
|
||||
all_models["convert_fp32"] = create_pass_from_dict(
|
||||
OnnxConversion, {"torch_dtype": "float32"}, disable_search=True
|
||||
).run(input_model, tmp_path / "convert_fp32")
|
||||
# transformers opt fp32
|
||||
all_models["opt_fp32"] = create_pass_from_dict(
|
||||
OrtTransformersOptimization, {"model_type": "bert", "opt_level": 0}, disable_search=True
|
||||
).run(all_models["convert_fp32"], tmp_path / "opt_fp32")
|
||||
# transformers opt fp16
|
||||
all_models["opt_fp16"] = create_pass_from_dict(
|
||||
OrtTransformersOptimization,
|
||||
{"model_type": "bert", "opt_level": 0, "float16": True, "keep_io_types": False},
|
||||
disable_search=True,
|
||||
).run(all_models["convert_fp32"], tmp_path / "opt_fp16")
|
||||
# transformers opt fp16 with keep_io_types
|
||||
all_models["opt_fp16_keep_io_types"] = create_pass_from_dict(
|
||||
OrtTransformersOptimization,
|
||||
{"model_type": "bert", "opt_level": 0, "float16": True, "keep_io_types": True},
|
||||
disable_search=True,
|
||||
).run(all_models["convert_fp32"], tmp_path / "opt_fp16_keep_io_types")
|
||||
|
||||
return all_models
|
||||
|
||||
|
||||
def common_check(tmp_path, input_model_info, model_type, include_all_nodes):
|
||||
input_model = input_model_info[model_type]
|
||||
|
||||
split_model = create_pass_from_dict(SplitModel, {"include_all_nodes": include_all_nodes}, disable_search=True).run(
|
||||
input_model, tmp_path
|
||||
)
|
||||
|
||||
assert isinstance(split_model, CompositeModelHandler)
|
||||
components = list(split_model.model_components)
|
||||
assert len(components) == 2
|
||||
assert all(isinstance(component, ONNXModelHandler) for component in components)
|
||||
|
||||
# check that the split models can be loaded
|
||||
# TODO(jambayk): Consider running the full model and comparing the outputs with the splits
|
||||
# this is more involved. have to modify the input model to create outputs for the split outputs
|
||||
for component in components:
|
||||
component.prepare_session(execution_providers="CPUExecutionProvider")
|
||||
|
||||
# check that the splits have the expected kv inputs
|
||||
for split_idx in range(2):
|
||||
io_config = components[split_idx].io_config
|
||||
|
||||
expected_kv_inputs = [f"past_key_values.{split_idx}.key", f"past_key_values.{split_idx}.value"]
|
||||
if not include_all_nodes and model_type == "opt_fp16_keep_io_types":
|
||||
# key is used before the firt split too
|
||||
expected_kv_inputs.remove(f"past_key_values.{split_idx}.key")
|
||||
assert set(expected_kv_inputs) <= set(io_config["input_names"])
|
||||
|
||||
expected_kv_outputs = [f"present.{split_idx}.key", f"present.{split_idx}.value"]
|
||||
assert set(expected_kv_outputs) <= set(io_config["output_names"])
|
||||
|
||||
return input_model, components
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_type",
|
||||
["convert_fp32", "opt_fp32", "opt_fp16", "opt_fp16_keep_io_types"],
|
||||
)
|
||||
def test_split_model_all_nodes(tmp_path, input_model_info, model_type):
|
||||
input_model, components = common_check(tmp_path, input_model_info, model_type, True)
|
||||
input_io_config = input_model.io_config
|
||||
|
||||
# input of first split must be a subset of the input of the input model
|
||||
assert set(components[0].io_config["input_names"]) <= set(input_io_config["input_names"])
|
||||
# input of second split must be a subset of model input + first split output
|
||||
assert set(components[1].io_config["input_names"]) <= set(
|
||||
input_io_config["input_names"] + components[0].io_config["output_names"]
|
||||
)
|
||||
# output of first split not used by second split must be a subset of the output of the input model
|
||||
split_1_model_output = set(components[0].io_config["output_names"]) - set(components[1].io_config["input_names"])
|
||||
assert split_1_model_output <= set(input_io_config["output_names"])
|
||||
# output of second split must be a subset of the output of the input model
|
||||
split_2_model_output = set(components[1].io_config["output_names"])
|
||||
assert split_2_model_output <= set(input_io_config["output_names"])
|
||||
# all split model outputs must be the same as the input model outputs
|
||||
assert (split_1_model_output | split_2_model_output) == set(input_io_config["output_names"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_type", "expected_overlap"),
|
||||
[
|
||||
("convert_fp32", 3),
|
||||
("opt_fp32", 4),
|
||||
("opt_fp16", 4),
|
||||
("opt_fp16_keep_io_types", 4),
|
||||
],
|
||||
)
|
||||
def test_split_model_only_splits(tmp_path, input_model_info, model_type, expected_overlap):
|
||||
_, components = common_check(tmp_path, input_model_info, model_type, False)
|
||||
|
||||
# check that the splits have the expected overlap
|
||||
assert (
|
||||
len(set(components[0].io_config["output_names"]).intersection(set(components[1].io_config["input_names"])))
|
||||
== expected_overlap
|
||||
)
|
|
@ -0,0 +1,59 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from olive.model import HfModelHandler, PyTorchModelHandler
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.pytorch.capture_split_info import CaptureSplitInfo
|
||||
|
||||
|
||||
class CustomModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.before_layer = torch.nn.Linear(2, 4)
|
||||
self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(4)])
|
||||
self.after_layer = torch.nn.Linear(4, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.before_layer(x)
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return self.after_layer(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input_model", "block_to_split", "num_splits", "split_assignments"),
|
||||
[
|
||||
(
|
||||
PyTorchModelHandler(model_loader=lambda _: CustomModel()),
|
||||
"layers",
|
||||
2,
|
||||
{"layers.0": 0, "layers.1": 0, "layers.2": 1, "layers.3": 1},
|
||||
),
|
||||
(
|
||||
PyTorchModelHandler(model_loader=lambda _: CustomModel()),
|
||||
"",
|
||||
3,
|
||||
{"before_layer": 0, "layers": 1, "after_layer": 2},
|
||||
),
|
||||
(
|
||||
HfModelHandler(model_path="hf-internal-testing/tiny-random-LlamaForCausalLM"),
|
||||
None,
|
||||
2,
|
||||
{"model.layers.0": 0, "model.layers.1": 1},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_capture_split_info(input_model, block_to_split, num_splits, split_assignments, tmp_path):
|
||||
config = {
|
||||
"num_splits": num_splits,
|
||||
}
|
||||
if block_to_split is not None:
|
||||
config["block_to_split"] = block_to_split
|
||||
p = create_pass_from_dict(CaptureSplitInfo, config, disable_search=True)
|
||||
|
||||
out = p.run(input_model, str(tmp_path))
|
||||
assert out.model_attributes["split_assignments"] == split_assignments
|