phi2 conversion/optimization script (#19338)
### Description
This PR adds an ONNX conversion script for the Dynamo-exported phi2 model, an optimization script, and an inference example script. A README file is added as documentation: https://github.com/microsoft/onnxruntime/tree/wangye/phi2_doc/onnxruntime/python/tools/transformers/models/phi2#readme

### Motivation and Context
---------
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Parent: e6d3518db9
Commit: aaf32fb1b1
@ -473,6 +473,9 @@ file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
  "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py"
)
file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS
  "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py"
)
file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS
  "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py"
)

@ -543,6 +546,7 @@ add_custom_command(
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/whisper

@ -646,6 +650,9 @@ add_custom_command(
  COMMAND ${CMAKE_COMMAND} -E copy
      ${onnxruntime_python_transformers_models_longformer_src}
      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer/
  COMMAND ${CMAKE_COMMAND} -E copy
      ${onnxruntime_python_transformers_models_phi2_src}
      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2/
  COMMAND ${CMAKE_COMMAND} -E copy
      ${onnxruntime_python_transformers_models_stable_diffusion_src}
      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion/
@ -205,6 +205,7 @@ class SymbolicShapeInference:
            "GemmFastGelu": self._infer_GemmFastGelu,
            "GemmFloat8": self._infer_GemmFloat8,
            "GroupNorm": self._infer_GroupNorm,
            "GroupQueryAttention": self._infer_GroupQueryAttention,
            "SkipGroupNorm": self._infer_SkipGroupNorm,
            "LayerNormalization": self._infer_LayerNormalization,
            "LongformerAttention": self._infer_LongformerAttention,

@ -471,6 +472,7 @@ class SymbolicShapeInference:
            "PythonOp",
            "MultiHeadAttention",
            "GroupNorm",
            "GroupQueryAttention",
            "SkipGroupNorm",
            "BiasSplitGelu",
            "BiasAdd",

@ -2409,6 +2411,32 @@ class SymbolicShapeInference:
    def _infer_GroupNorm(self, node):  # noqa: N802
        self._propagate_shape_and_type(node)

    def _infer_GroupQueryAttention(self, node):  # noqa: N802
        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type

        past_shape = self._try_get_shape(node, 3)
        if past_shape is not None:
            vi = self.known_vi_[node.output[1]]
            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
            vi = self.known_vi_[node.output[2]]
            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))

        if node.input[1] != "" and node.input[2] != "":
            self._propagate_shape_and_type(node, 0, 0)
        else:
            # combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size)
            assert node.input[1] == "" and node.input[2] == ""
            num_heads = get_attribute(node, "num_heads")
            kv_num_heads = get_attribute(node, "kv_num_heads")
            query_shape = self._get_shape(node, 0)
            if query_shape is not None:
                hidden_size = query_shape[2]
                if isinstance(hidden_size, int):
                    head_size = int(hidden_size / (num_heads + 2 * kv_num_heads))
                    query_shape[2] = num_heads * head_size
                    vi = self.known_vi_[node.output[0]]
                    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape))

    def _infer_SkipGroupNorm(self, node):  # noqa: N802
        self._propagate_shape_and_type(node, 0, 0)
        if len(node.output) > 1:
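In the packed-QKV branch above, the head size is recovered purely from the query's hidden dimension and the two head-count attributes. A minimal numeric sketch, assuming the phi2 configuration that appears later in this PR (32 attention heads, 32 KV heads, head size 80):

```python
# Packed QKV hidden dim = num_heads*head_size + 2*kv_num_heads*head_size.
num_heads, kv_num_heads, head_size = 32, 32, 80  # phi2 values assumed from this PR
packed_hidden = num_heads * head_size + 2 * kv_num_heads * head_size  # 7680

# Shape inference reverses that to recover head_size, then the output hidden dim:
recovered_head_size = packed_hidden // (num_heads + 2 * kv_num_heads)  # 80
output_hidden = num_heads * recovered_head_size                        # 2560
assert (recovered_head_size, output_hidden) == (80, 2560)
```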
@ -0,0 +1,92 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import onnx


class DynamoOnnxHelper:
    """
    Helper class for processing ONNX models exported by torch Dynamo.
    """

    def __init__(self, model: onnx.ModelProto):
        self.model = model

    def update_edges(self, edge_mapping: dict) -> None:
        """
        Updates the edges in the model according to the given mapping.
        """
        for node in self.model.graph.node:
            for i in range(len(node.input)):
                if node.input[i] in edge_mapping:
                    node.input[i] = edge_mapping[node.input[i]]
            for i in range(len(node.output)):
                if node.output[i] in edge_mapping:
                    node.output[i] = edge_mapping[node.output[i]]

        for graph_input in self.model.graph.input:
            if graph_input.name in edge_mapping:
                graph_input.name = edge_mapping[graph_input.name]
        for graph_output in self.model.graph.output:
            if graph_output.name in edge_mapping:
                graph_output.name = edge_mapping[graph_output.name]

    def unroll_function(self, func_name: str) -> None:
        """
        Unrolls the function with the given name in the model.
        """
        logging.info(f"Unrolling function {func_name}...")
        nodes_to_remove = []
        nodes_to_add = []
        edges_to_remove = []
        edges_to_add = []
        for node in self.model.graph.node:
            if node.op_type == func_name:
                nodes_to_remove.append(node)
                edges_to_remove.extend(list(node.input) + list(node.output))

        func_to_remove = None
        for f in self.model.functions:
            if f.name == func_name:
                nodes_to_add.extend(list(f.node))
                edges_to_add.extend(list(f.input) + list(f.output))
                func_to_remove = f

        assert len(edges_to_remove) == len(edges_to_add)

        for node in nodes_to_remove:
            self.model.graph.node.remove(node)
        for node in nodes_to_add:
            self.model.graph.node.append(node)
        if func_to_remove is not None:
            self.model.functions.remove(func_to_remove)

        edge_mapping = {}
        for i in range(len(edges_to_remove)):
            k = edges_to_remove[i]
            v = edges_to_add[i]
            if k != v:
                edge_mapping[k] = v

        return self.update_edges(edge_mapping)

    def remove_dropout_layer(self) -> None:
        """
        Removes the dropout layers in the model.
        """
        logging.info("Removing dropout layer...")
        edge_mapping = {}
        nodes_to_remove = []
        for node in self.model.graph.node:
            if node.op_type.find("Dropout") != -1:
                assert len(node.input) == 1
                assert len(node.output) == 1
                edge_mapping[node.output[0]] = node.input[0]
                nodes_to_remove.append(node)
        for node in nodes_to_remove:
            self.model.graph.node.remove(node)

        self.update_edges(edge_mapping)
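The helper above operates in place on a loaded ModelProto. A minimal usage sketch (the file paths and the unrolled function name are placeholders; the phi2 preprocessor later in this PR supplies the real function name at runtime):

```python
import onnx

from dynamo_onnx_helper import DynamoOnnxHelper

# Load a Dynamo-exported model; "model.onnx" is a placeholder path.
model = onnx.load("model.onnx")
helper = DynamoOnnxHelper(model)

# Inline one exported local function into the main graph (hypothetical name),
# then drop Dropout nodes, which are no-ops at inference time.
helper.unroll_function("some_exported_function")  # assumption: such a function exists
helper.remove_dropout_layer()

# Rename selected edges, e.g. to give graph outputs stable names.
helper.update_edges({"lm_head_1": "logits"})

onnx.save(helper.model, "model_processed.onnx")
```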
@ -174,6 +174,7 @@ def convert_float_to_float16(
    node_block_list=None,
    force_fp16_initializers=False,
    force_fp16_inputs=None,
    use_bfloat16_as_blocked_nodes_dtype=False,
):
    """Convert tensor float type in the input ONNX model to tensor float16.

@ -436,6 +437,7 @@ def convert_float_to_float16(
                        node.input[i] = output_name
                        break

        accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT
        # process the nodes in the block list that don't support tensor(float16)
        for node in node_list:
            # if the input's name is in value_info_list, the input is of tensor(float16) type,

@ -450,10 +452,10 @@ def convert_float_to_float16(
                        new_value_info.CopyFrom(value_info)
                        output_name = node.name + "_input_cast_" + str(i)
                        new_value_info.name = output_name
                        new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                        new_value_info.type.tensor_type.elem_type = accuracy_type
                        # add a Cast node (from tensor(float16) to tensor(float)) before the current node
                        node_name = node.name + "_input_cast" + str(i)
                        new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
                        new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)]
                        model.graph.node.extend(new_node)
                        # change the current node's input name
                        node.input[i] = output_name

@ -469,7 +471,7 @@ def convert_float_to_float16(
                    new_value_info.CopyFrom(value_info)
                    input_name = node.name + "_output_cast_" + str(i)
                    new_value_info.name = input_name
                    new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                    new_value_info.type.tensor_type.elem_type = accuracy_type
                    # add a Cast node (from tensor(float) to tensor(float16)) after the current node
                    node_name = node.name + "_output_cast" + str(i)
                    new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)]
@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from argparse import ArgumentParser
from enum import Enum


class AttentionMaskFormat:

@ -19,6 +20,15 @@ class AttentionMaskFormat:
    NoMask = 3


class AttentionOpType(Enum):
    Attention = "Attention"
    MultiHeadAttention = "MultiHeadAttention"
    GroupQueryAttention = "GroupQueryAttention"

    def __str__(self):
        return self.value


class FusionOptions:
    """Options of fusion in graph optimization"""


@ -57,6 +67,8 @@ class FusionOptions:
        elif model_type == "vit":
            self.attention_mask_format = AttentionMaskFormat.NoMask

        self.attention_op_type = None

        # options for stable diffusion
        if model_type in ["unet", "vae", "clip"]:
            self.enable_nhwc_conv = True

@ -76,6 +88,9 @@ class FusionOptions:
    def disable_attention_mask(self):
        self.attention_mask_format = AttentionMaskFormat.NoMask

    def set_attention_op_type(self, attn_op_type: AttentionOpType):
        self.attention_op_type = attn_op_type

    @staticmethod
    def parse(args):
        options = FusionOptions(args.model_type)
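The new AttentionOpType enum is consumed through FusionOptions when building the optimized graph. A minimal sketch of how it is wired up, mirroring convert_to_onnx.py later in this PR (the model path is a placeholder, and the num_heads/hidden_size values are phi2 config values that convert_to_onnx.py normally reads from the Hugging Face config):

```python
from fusion_options import AttentionOpType, FusionOptions
from optimizer import optimize_model

# Select the attention contrib op that the phi fusion pass should emit.
options = FusionOptions("phi")
options.set_attention_op_type(AttentionOpType.GroupQueryAttention)

optimizer = optimize_model(
    "phi2_original.onnx",  # placeholder path to the Dynamo-exported model
    model_type="phi",
    num_heads=32,          # assumed phi2 values; taken from AutoConfig in convert_to_onnx.py
    hidden_size=2560,
    opt_level=0,
    optimization_options=options,
    only_onnxruntime=False,
)
optimizer.save_model_to_file("phi2_optimized.onnx", use_external_data_format=True)
```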
@ -0,0 +1,119 @@
# Phi2 Optimizations
## Prerequisites
A Linux machine is required for the [TorchDynamo-based ONNX Exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter).\
Install onnx, onnxscript and transformers by running
```bash
pip install -r requirements.txt
```
To export the ONNX model, PyTorch 2.2.0 or higher is required. The [official website](https://pytorch.org/) offers packages compatible with CUDA 11.8 and 12.1. Please select the appropriate version according to your needs.
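For example, a CUDA 12.1 build can typically be installed with the command below (taken from PyTorch's standard install instructions; adjust the index URL to match your CUDA version):
```bash
pip install "torch>=2.2.0" --index-url https://download.pytorch.org/whl/cu121
```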

**There are two options to run the conversion script:**\
_From source:_
```bash
pip install onnxruntime-gpu==1.17.0  # or onnxruntime==1.17.0 if using CPU
git clone git@github.com:microsoft/onnxruntime.git
cd onnxruntime/onnxruntime/python/tools/transformers
python -m models.phi2.convert_to_onnx -h
```
_From wheel:_ \
Install the [ORT nightly package](https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages)
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx -h
```

## Export optimized phi2 onnx model for different scenarios
**Export FP32 ONNX model for Nvidia GPUs** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --fp32_gpu
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_gpu
```
\
**Export FP16 ONNX model for Nvidia GPUs** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --fp16_gpu
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu
```
\
**Export INT4 ONNX model for Nvidia GPUs** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --int4_gpu
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu
```
\
**Export FP16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x
```
\
**Export INT4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --int4_gpu_sm8x
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu_sm8x
```
\
**Export FP32 ONNX model for CPU** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --fp32_cpu
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu
```
\
**Export INT4 ONNX model for CPU** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --int4_cpu
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_cpu
```
\
**Export all at once** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x
```
## Run example with ORT
**E.g., export FP16 and INT4 ONNX models for Nvidia GPUs with CUDA architecture SM=80~89 and run the examples.** \
_From source:_
```bash
python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example
```
_From wheel:_
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example
```
The inference example currently supports all models running on CUDA.

## Limitations
- The TorchDynamo-based ONNX Exporter only supports Linux.
- The program may not run as expected if the machine has limited memory, e.g., the Dynamo export may use ~11.6 GB and each optimization may use ~4.5 GB.
@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys

sys.path.append(os.path.dirname(__file__))

transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
@ -0,0 +1,458 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from benchmark_helper import Precision
|
||||
from fusion_options import AttentionOpType
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
|
||||
|
||||
|
||||
class ConvertPhi2ToONNX:
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device,
|
||||
model_class: str = "microsoft/phi-2",
|
||||
cache_dir: str = "./cache",
|
||||
):
|
||||
self.model_class = model_class
|
||||
self.device = device
|
||||
self.cache_dir = cache_dir
|
||||
self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir)
|
||||
self.phi_model = None
|
||||
self.batch_size = 2
|
||||
self.sequence_length = 8
|
||||
self.attn_op_type = None
|
||||
self.precision = None
|
||||
self.block_size = 16
|
||||
self.accuracy_level = None
|
||||
|
||||
def set_quantization_params(self, block_size: int, accuracy_level: int | None):
|
||||
self.block_size = block_size
|
||||
self.accuracy_level = accuracy_level
|
||||
|
||||
def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision):
|
||||
self.attn_op_type = attn_op_type
|
||||
self.precision = precision
|
||||
|
||||
def erase_onnx_model(self, onnx_path: str) -> None:
|
||||
assert onnx_path.endswith(".onnx")
|
||||
if not os.path.exists(onnx_path):
|
||||
return
|
||||
|
||||
model = onnx.load_model(onnx_path, load_external_data=False)
|
||||
onnx_data_path = None
|
||||
for initializer in model.graph.initializer:
|
||||
if initializer.data_location == 1 and initializer.external_data[0].key == "location":
|
||||
onnx_data_path = "./" + initializer.external_data[0].value
|
||||
break
|
||||
logging.info(f"Erasing {onnx_path}...")
|
||||
os.remove(onnx_path)
|
||||
if onnx_data_path is not None:
|
||||
onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path)
|
||||
logging.info(f"Erasing {onnx_data_path}...")
|
||||
os.remove(onnx_data_path)
|
||||
|
||||
def get_phi2_torch_model(self):
|
||||
logging.info("Loading phi2 torch model...")
|
||||
if self.phi_model is not None:
|
||||
return
|
||||
self.phi_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_class, trust_remote_code=True, cache_dir=self.cache_dir
|
||||
)
|
||||
self.phi_model.eval()
|
||||
self.phi_model.to(self.device)
|
||||
|
||||
def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int):
|
||||
input_ids = torch.randint(
|
||||
low=0,
|
||||
high=self.phi_config.vocab_size,
|
||||
size=(batch_size, sequence_length),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
self.get_phi2_torch_model()
|
||||
torch_inputs = self.phi_model.prepare_inputs_for_generation(
|
||||
input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"]
|
||||
)
|
||||
return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"]
|
||||
|
||||
def dynamo_export(self, onnx_path: str):
|
||||
input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length)
|
||||
self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
|
||||
|
||||
from torch._dynamo import config
|
||||
|
||||
config.capture_scalar_outputs = True
|
||||
|
||||
logging.info("Exporting Phi2 torch model to ONNX...")
|
||||
torch.onnx.dynamo_export(
|
||||
self.phi_model,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
|
||||
).save(onnx_path)
|
||||
onnx.checker.check_model(onnx_path)
|
||||
onnx.shape_inference.infer_shapes_path(onnx_path)
|
||||
|
||||
def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
|
||||
from fusion_options import FusionOptions
|
||||
from optimizer import optimize_model
|
||||
|
||||
optimization_options = FusionOptions("phi")
|
||||
optimization_options.set_attention_op_type(self.attn_op_type)
|
||||
optimizer = optimize_model(
|
||||
onnx_path,
|
||||
model_type="phi",
|
||||
num_heads=self.phi_config.num_attention_heads,
|
||||
hidden_size=self.phi_config.hidden_size,
|
||||
opt_level=0,
|
||||
optimization_options=optimization_options,
|
||||
only_onnxruntime=False,
|
||||
)
|
||||
|
||||
fused_op_count = optimizer.get_fused_operator_statistics()
|
||||
if optimizer.is_fully_optimized(fused_op_count):
|
||||
logging.info("Model is fully optimized.")
|
||||
else:
|
||||
logging.info("Model is not fully optimized.")
|
||||
|
||||
if self.precision == Precision.FLOAT32:
|
||||
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
|
||||
return
|
||||
|
||||
if (
|
||||
self.precision == Precision.FLOAT16 or self.precision == Precision.INT4
|
||||
) and self.attn_op_type != AttentionOpType.MultiHeadAttention:
|
||||
# We keep last three layers of Attention as float32 or bfloat16 to avoid overflow.
|
||||
node_block_list = [
|
||||
"GroupQueryAttention_29",
|
||||
"GroupQueryAttention_30",
|
||||
"GroupQueryAttention_31",
|
||||
"Attention_29",
|
||||
"Attention_30",
|
||||
"Attention_31",
|
||||
]
|
||||
logging.info("Converting onnx model to float16/bfloat16...")
|
||||
optimizer.convert_float_to_float16(
|
||||
keep_io_types=False,
|
||||
node_block_list=node_block_list,
|
||||
use_symbolic_shape_infer=True,
|
||||
use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention,
|
||||
)
|
||||
logging.info("Converting onnx model to float16/bfloat16 done.")
|
||||
|
||||
if self.precision == Precision.FLOAT16:
|
||||
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
|
||||
return
|
||||
else:
|
||||
assert self.precision == Precision.INT4
|
||||
quant = MatMul4BitsQuantizer(
|
||||
model=optimizer.model,
|
||||
block_size=self.block_size,
|
||||
is_symmetric=True,
|
||||
accuracy_level=self.accuracy_level,
|
||||
)
|
||||
quant.process()
|
||||
quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--fp32_cpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp32 ONNX model for CPU",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_cpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate int4 ONNX model for CPU",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp32_gpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp32 ONNX model for Nvidia GPUs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_gpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp16 ONNX model for Nvidia GPUs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_gpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate int4 ONNX model for Nvidia GPUs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_gpu_sm8x",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_gpu_sm8x",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Overwrite existing ONNX models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./cache",
|
||||
help="The cache directory for the pytorch model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--device_id",
|
||||
required=False,
|
||||
type=int,
|
||||
default=0,
|
||||
help="The device id for the pytorch model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--run_example",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Run ORT inference example",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_export",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Skip exporting ONNX model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
help="The output directory for the ONNX models",
|
||||
default="phi2_onnx_models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--block_size",
|
||||
required=False,
|
||||
default=16,
|
||||
type=int,
|
||||
help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_accuracy_level",
|
||||
required=False,
|
||||
type=int,
|
||||
help="Accuracy level of the 4-bit quantized MatMul computation. "
|
||||
"Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
|
||||
"(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir)
|
||||
converter.set_quantization_params(args.block_size, args.int4_accuracy_level)
|
||||
|
||||
output_dir = args.output_dir
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
original_onnx_path = os.path.join(output_dir, "phi2_original.onnx")
|
||||
|
||||
if not args.skip_export:
|
||||
if not os.path.exists(original_onnx_path) or args.overwrite:
|
||||
converter.dynamo_export(original_onnx_path)
|
||||
|
||||
model_type_to_args = {
|
||||
"fp32_cpu": (
|
||||
AttentionOpType.MultiHeadAttention,
|
||||
Precision.FLOAT32,
|
||||
os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"),
|
||||
),
|
||||
"int4_cpu": (
|
||||
AttentionOpType.MultiHeadAttention,
|
||||
Precision.INT4,
|
||||
os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"),
|
||||
),
|
||||
"fp32_gpu": (
|
||||
AttentionOpType.Attention,
|
||||
Precision.FLOAT32,
|
||||
os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"),
|
||||
),
|
||||
"fp16_gpu": (
|
||||
AttentionOpType.Attention,
|
||||
Precision.FLOAT16,
|
||||
os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"),
|
||||
),
|
||||
"int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")),
|
||||
"fp16_gpu_sm8x": (
|
||||
AttentionOpType.GroupQueryAttention,
|
||||
Precision.FLOAT16,
|
||||
os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"),
|
||||
),
|
||||
"int4_gpu_sm8x": (
|
||||
AttentionOpType.GroupQueryAttention,
|
||||
Precision.INT4,
|
||||
os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
|
||||
),
|
||||
}
|
||||
|
||||
if not args.skip_export:
|
||||
from multiprocessing import Process
|
||||
|
||||
def run_optimize_phi2_onnx(
|
||||
converter: ConvertPhi2ToONNX,
|
||||
original_onnx_path: str,
|
||||
attention_type: AttentionOpType,
|
||||
precision: Precision,
|
||||
optimized_onnx_path: str,
|
||||
):
|
||||
converter.init_attn_type_and_precision(attention_type, precision)
|
||||
converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path)
|
||||
|
||||
processes = []
|
||||
if args.fp32_cpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.int4_cpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.fp32_gpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.fp16_gpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.int4_gpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.fp16_gpu_sm8x:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx,
|
||||
args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]),
|
||||
)
|
||||
)
|
||||
|
||||
if args.int4_gpu_sm8x:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx,
|
||||
args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]),
|
||||
)
|
||||
)
|
||||
|
||||
[p.start() for p in processes]
|
||||
[p.join() for p in processes]
|
||||
|
||||
if args.run_example:
|
||||
from inference_example import run_phi2
|
||||
|
||||
if args.fp16_gpu_sm8x:
|
||||
logging.info("Running fp16_gpu_sm8x example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2],
|
||||
use_buffer_share=True,
|
||||
device_id=args.device_id,
|
||||
use_step=True,
|
||||
)
|
||||
if args.int4_gpu_sm8x:
|
||||
logging.info("Running int4_gpu_sm8x example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2],
|
||||
use_buffer_share=True,
|
||||
device_id=args.device_id,
|
||||
use_step=True,
|
||||
)
|
||||
if args.fp32_gpu:
|
||||
logging.info("Running fp32_gpu example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["fp32_gpu"][2],
|
||||
use_buffer_share=False,
|
||||
device_id=args.device_id,
|
||||
packed_kv=True,
|
||||
use_fp16=False,
|
||||
)
|
||||
if args.fp16_gpu:
|
||||
logging.info("Running fp16_gpu example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["fp16_gpu"][2],
|
||||
use_buffer_share=False,
|
||||
device_id=args.device_id,
|
||||
packed_kv=True,
|
||||
)
|
||||
if args.int4_gpu:
|
||||
logging.info("Running int4_gpu example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["int4_gpu"][2],
|
||||
use_buffer_share=False,
|
||||
device_id=args.device_id,
|
||||
packed_kv=True,
|
||||
)
|
||||
if args.fp32_cpu or args.int4_cpu:
|
||||
raise NotImplementedError("CPU inference example is not implemented yet.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,215 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
pt_to_np = {
|
||||
"torch.int32": np.int32,
|
||||
"torch.int64": np.int64,
|
||||
"torch.float32": np.float32,
|
||||
"torch.float16": np.float16,
|
||||
}
|
||||
|
||||
|
||||
class ORTGenerator:
|
||||
def __init__(self, decoder_path):
|
||||
self.onnx_decoder_path = decoder_path
|
||||
self.num_heads = 32
|
||||
self.head_size = 80
|
||||
self.num_layers = 32
|
||||
self.max_sequence_length = 2048
|
||||
|
||||
def get_initial_inputs_and_outputs(self, encodings_dict):
|
||||
self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
|
||||
|
||||
input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
|
||||
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
|
||||
step = torch.tensor([0], device=self.device, dtype=torch.int64)
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids.contiguous(),
|
||||
"attention_mask": attention_mask.contiguous(),
|
||||
}
|
||||
|
||||
if self.use_step:
|
||||
inputs["step"] = step.contiguous()
|
||||
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
|
||||
past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
|
||||
past_shape = (
|
||||
(2, batch_size, self.num_heads, past_seq_length, self.head_size)
|
||||
if self.packed_kv
|
||||
else (batch_size, self.num_heads, past_seq_length, self.head_size)
|
||||
)
|
||||
for i in range(self.num_layers):
|
||||
past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
|
||||
inputs.update(
|
||||
{f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()}
|
||||
) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()})
|
||||
|
||||
logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
|
||||
outputs = {"logits": logits.contiguous()}
|
||||
|
||||
if not self.use_buffer_share:
|
||||
present_shape = (
|
||||
(2, batch_size, self.num_heads, sequence_length, self.head_size)
|
||||
if self.packed_kv
|
||||
else (batch_size, self.num_heads, sequence_length, self.head_size)
|
||||
)
|
||||
for i in range(self.num_layers):
|
||||
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
|
||||
outputs.update(
|
||||
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
|
||||
) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()})
|
||||
|
||||
return inputs, outputs
|
||||
|
||||
def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
|
||||
io_binding = model.io_binding()
|
||||
device = None
|
||||
|
||||
for k, v in inputs.items():
|
||||
io_binding.bind_input(
|
||||
name=k,
|
||||
device_type=v.device.type,
|
||||
device_id=0 if v.device.type == "cpu" else v.device.index,
|
||||
element_type=pt_to_np[repr(v.dtype)],
|
||||
shape=tuple(v.shape),
|
||||
buffer_ptr=v.data_ptr(),
|
||||
)
|
||||
device = v.device
|
||||
|
||||
for output in model.get_outputs():
|
||||
name = output.name
|
||||
if self.use_buffer_share and "present" in name:
|
||||
v = inputs[name.replace("present", "past")]
|
||||
io_binding.bind_output(
|
||||
name=name,
|
||||
device_type=v.device.type,
|
||||
device_id=v.device.index,
|
||||
element_type=(np.float16 if self.use_fp16 else np.float32),
|
||||
shape=tuple(v.shape),
|
||||
buffer_ptr=v.data_ptr(),
|
||||
)
|
||||
else:
|
||||
v = outputs[name]
|
||||
io_binding.bind_output(
|
||||
name=name,
|
||||
device_type=device.type,
|
||||
device_id=0 if device.type == "cpu" else device.index,
|
||||
element_type=(np.float16 if self.use_fp16 else np.float32),
|
||||
shape=tuple(v.shape),
|
||||
buffer_ptr=v.data_ptr(),
|
||||
)
|
||||
|
||||
return io_binding
|
||||
|
||||
def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False):
|
||||
sess_options = ort.SessionOptions()
|
||||
ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider"
|
||||
self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
|
||||
|
||||
self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu")
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_buffer_share = use_buffer_share
|
||||
self.packed_kv = packed_kv
|
||||
self.use_step = use_step
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
||||
self.tokenizer.pad_token = "[PAD]"
|
||||
|
||||
def generate(self, prompt, max_length):
|
||||
encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
|
||||
|
||||
inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
|
||||
|
||||
all_token_ids = inputs["input_ids"].clone()
|
||||
batch_size, sequence_length = all_token_ids.shape
|
||||
|
||||
current_length = sequence_length
|
||||
has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
|
||||
|
||||
while current_length < max_length:
|
||||
io_binding = self.apply_io_binding(self.sess, inputs, outputs)
|
||||
|
||||
io_binding.synchronize_inputs()
|
||||
self.sess.run_with_iobinding(io_binding)
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
# Sample with argmax (greedy search)
|
||||
next_token_logits = outputs["logits"][:, -1, :]
|
||||
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
||||
|
||||
# Check if we previously reached EOS token id or if generated token id is EOS token id
|
||||
has_eos = has_eos | (next_tokens == self.tokenizer.eos_token_id)
|
||||
|
||||
# Determine which new tokens to add to list of all token ids
|
||||
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
|
||||
tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
|
||||
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
|
||||
|
||||
# Return early if all batch entries have reached EOS token id
|
||||
if torch.all(has_eos):
|
||||
break
|
||||
|
||||
# Update inputs for next inference run
|
||||
current_length += 1
|
||||
inputs["input_ids"] = tokens_to_add.to(torch.int32)
|
||||
if self.use_step:
|
||||
inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
|
||||
inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to(
|
||||
torch.int32
|
||||
)
|
||||
|
||||
# Set logits to zeros for next inference run and re-use memory buffer
|
||||
if outputs["logits"].shape[1] != 1:
|
||||
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
|
||||
outputs["logits"].zero_()
|
||||
|
||||
if not self.use_buffer_share:
|
||||
for i in range(self.num_layers):
|
||||
if not self.packed_kv:
|
||||
inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
|
||||
inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
|
||||
else:
|
||||
inputs[f"past_{i}"] = outputs[f"present_{i}"]
|
||||
|
||||
new_sequence_length = inputs["attention_mask"].shape[1]
|
||||
present_shape = (
|
||||
(2, batch_size, self.num_heads, new_sequence_length, self.head_size)
|
||||
if self.packed_kv
|
||||
else (batch_size, self.num_heads, new_sequence_length, self.head_size)
|
||||
)
|
||||
for i in range(self.num_layers):
|
||||
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
|
||||
outputs.update(
|
||||
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()}
|
||||
) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()})
|
||||
|
||||
texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
|
||||
return texts
|
||||
|
||||
|
||||
def run_phi2(onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False):
|
||||
prompt = [
|
||||
'''```python
|
||||
def print_prime(n):
|
||||
"""
|
||||
Print all primes between 1 and n
|
||||
"""'''
|
||||
]
|
||||
|
||||
generator = ORTGenerator(onnx_model_path)
|
||||
generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step)
|
||||
texts = generator.generate(prompt, max_length=200)
|
||||
|
||||
for i in range(len(texts)):
|
||||
print("Prompt: ", prompt[i])
|
||||
print("Texts: ", texts[i])
|
|
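run_phi2 above can also be called directly once an optimized model has been exported, instead of going through convert_to_onnx --run_example. A minimal sketch, assuming the sm8x FP16 model was produced into the default output directory used by convert_to_onnx.py:

```python
from inference_example import run_phi2

# Path follows the default naming in convert_to_onnx.py (assumes default --output_dir).
run_phi2(
    onnx_model_path="phi2_onnx_models/phi2_decoder_fp16_gpu_sm8x.onnx",
    use_buffer_share=True,  # GQA models share past/present KV buffers
    device_id=0,
    use_step=True,          # the sm8x graphs take an explicit "step" input
)
```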
@ -0,0 +1,3 @@
onnx>=1.15.0
transformers>=4.36.2
onnxscript>=0.1.0.dev20240126
@ -82,6 +82,10 @@ class OnnxModel:
                output_name_to_node[output_name] = node
        return output_name_to_node

    def functions(self):
        all_functions = [list(self.model.functions)]
        return all_functions

    def nodes(self):
        all_nodes = []
        for graph in self.graphs():

@ -733,6 +737,7 @@ class OnnxModel:
                "node_block_list",
                "force_fp16_initializers",
                "force_fp16_inputs",
                "use_bfloat16_as_blocked_nodes_dtype",
            ]
            if key in kwargs
        }
@ -0,0 +1,839 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from dynamo_onnx_helper import DynamoOnnxHelper
|
||||
from fusion_base import Fusion
|
||||
from fusion_options import AttentionOpType, FusionOptions
|
||||
from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
|
||||
from fusion_utils import NumpyHelper
|
||||
from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class ProcessGemmWFunc:
|
||||
def __call__(self, x):
|
||||
return np.transpose(x, (1, 0))
|
||||
|
||||
|
||||
class ProcessMatMulQFunc:
|
||||
def __call__(self, x):
|
||||
return np.transpose(np.split(x, 3, 0)[0], (1, 0))
|
||||
|
||||
|
||||
class ProcessMatMulKFunc:
|
||||
def __call__(self, x):
|
||||
return np.transpose(np.split(x, 3, 0)[1], (1, 0))
|
||||
|
||||
|
||||
class ProcessMatMulVFunc:
|
||||
def __call__(self, x):
|
||||
return np.transpose(np.split(x, 3, 0)[2], (1, 0))
|
||||
|
||||
|
||||
class ProcessBiasQFunc:
|
||||
def __call__(self, x):
|
||||
x = np.split(x, 3, -1)[0]
|
||||
return x
|
||||
|
||||
|
||||
class ProcessBiasKFunc:
|
||||
def __call__(self, x):
|
||||
x = np.split(x, 3, -1)[1]
|
||||
return x
|
||||
|
||||
|
||||
class ProcessBiasVFunc:
|
||||
def __call__(self, x):
|
||||
x = np.split(x, 3, -1)[2]
|
||||
return x
|
||||
|
||||
|
||||
class ProcessRotCacheFunc:
|
||||
def __call__(self, x):
|
||||
# half rotary embedding
|
||||
assert len(x.shape) == 2
|
||||
if x.shape[1] == 32:
|
||||
return x[:, 0:16]
|
||||
return x
|
||||
|
||||
|
||||
# TODO: move to a separate file
|
||||
class Fission(Fusion):
|
||||
def __init__(
|
||||
self,
|
||||
model: OnnxModel,
|
||||
nodes_to_find: List[str],
|
||||
):
|
||||
super().__init__(model, "DONOTUSE", nodes_to_find)
|
||||
|
||||
def set_attention_op_type(self, attn_op_type: AttentionOpType):
|
||||
self.attn_op_type = attn_op_type
|
||||
|
||||
def get_uname(self, layer_id, name):
|
||||
return name + "_" + str(layer_id)
|
||||
|
||||
def get_io_by_name(self, node, name):
|
||||
for input in node.input:
|
||||
if input == name or input.endswith(name) or input.startswith(name):
|
||||
return input
|
||||
for output in node.output:
|
||||
if output == name or output.endswith(name) or output.startswith(name):
|
||||
return output
|
||||
raise Exception(f"input {name} not found in node {node.name}")
|
||||
|
||||
def process_initializer(self, initializer_name, functor, custom_name=None):
|
||||
i = self.model.get_initializer(initializer_name)
|
||||
i_np_array = NumpyHelper.to_array(i)
|
||||
processed_i_np_array = functor(i_np_array)
|
||||
new_tensor = helper.make_tensor(
|
||||
initializer_name + "_processed" if custom_name is None else custom_name,
|
||||
data_type=TensorProto.FLOAT,
|
||||
dims=processed_i_np_array.shape,
|
||||
vals=processed_i_np_array.flatten().tobytes(),
|
||||
raw=True,
|
||||
)
|
||||
self.model.add_initializer(new_tensor, self.this_graph_name)
|
||||
return new_tensor.name
|
||||
|
||||
def add_fp32_value_info(self, name):
|
||||
new_value_info = self.model.graph().value_info.add()
|
||||
new_value_info.name = name
|
||||
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
|
||||
|
||||
def add_int64_value_info(self, name):
|
||||
new_value_info = self.model.graph().value_info.add()
|
||||
new_value_info.name = name
|
||||
new_value_info.type.tensor_type.elem_type = TensorProto.INT64
|
||||
|
||||
def replace_fp32_value_info(self, name, shape):
|
||||
for value_info in self.model.graph().value_info:
|
||||
if value_info.name == name:
|
||||
self.model.graph().value_info.remove(value_info)
|
||||
break
|
||||
new_value_info = helper.make_tensor_value_info(
|
||||
name,
|
||||
elem_type=TensorProto.FLOAT,
|
||||
shape=shape,
|
||||
)
|
||||
self.model.graph().value_info.extend([new_value_info])
|
||||
|
||||
def set_unique_name_and_add_nodes(
|
||||
self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str]
|
||||
):
|
||||
for new_node in subgraph_nodes:
|
||||
for i, name in enumerate(new_node.input):
|
||||
if name == "":
|
||||
continue
|
||||
elif name not in layer_known_edges_names:
|
||||
new_node.input[i] = self.get_uname(layer_id, name)
|
||||
self.add_fp32_value_info(new_node.input[i])
|
||||
for i, name in enumerate(new_node.output):
|
||||
if name == "":
|
||||
continue
|
||||
elif name not in layer_known_edges_names:
|
||||
new_node.output[i] = self.get_uname(layer_id, name)
|
||||
self.add_fp32_value_info(new_node.output[i])
|
||||
new_node.name = self.get_uname(layer_id, new_node.name)
|
||||
self.nodes_to_add.append(new_node)
|
||||
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
||||
|
||||
def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
|
||||
assert len(inputs) == 3
|
||||
assert len(outputs) == 1
|
||||
node = helper.make_node(
|
||||
"LayerNormalization",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=prefix + "_LayerNormalization",
|
||||
epsilon=9.999999747378752e-06,
|
||||
)
|
||||
return [node]
|
||||
|
||||
def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
|
||||
assert len(inputs) == 3
|
||||
assert len(outputs) == 1
|
||||
matmul = helper.make_node(
|
||||
"MatMul",
|
||||
inputs=[inputs[0], inputs[1]],
|
||||
outputs=[prefix + "matmul_out"],
|
||||
name=prefix + "MatMul",
|
||||
)
|
||||
add = helper.make_node(
|
||||
"Add",
|
||||
inputs=[prefix + "matmul_out", inputs[2]],
|
||||
outputs=outputs,
|
||||
name=prefix + "Bias",
|
||||
)
|
||||
return [matmul, add]
|
||||
|
||||
def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32):
|
||||
assert len(inputs) == 4
|
||||
assert len(outputs) == 1
|
||||
node = helper.make_node(
|
||||
"RotaryEmbedding",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=prefix + "RotaryEmbedding",
|
||||
domain="com.microsoft",
|
||||
rotary_embedding_dim=rot_dim,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
return [node]
|
||||
|
||||
def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""):
|
||||
assert len(inputs) == 1
|
||||
assert len(outputs) == 1
|
||||
node = helper.make_node(
|
||||
"FastGelu",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=prefix + "FastGelu",
|
||||
domain="com.microsoft",
|
||||
)
|
||||
return [node]
|
||||
|
||||
def add(self, inputs: List[str], outputs: List[str], prefix: str = ""):
|
||||
assert len(inputs) == 2
|
||||
assert len(outputs) == 1
|
||||
node = helper.make_node(
|
||||
"Add",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=prefix + "Add",
|
||||
)
|
||||
return [node]
|
||||
|
||||
def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
|
||||
assert len(inputs) == 8
|
||||
assert len(outputs) == 3
|
||||
node = helper.make_node(
|
||||
"MultiHeadAttention",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=prefix + "MultiHeadAttention",
|
||||
domain="com.microsoft",
|
||||
num_heads=num_heads,
|
||||
unidirectional=1,
|
||||
)
|
||||
return [node]
|
||||
|
||||
def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
|
||||
assert len(inputs) == 7
|
||||
assert len(outputs) == 3
|
||||
node = helper.make_node(
|
||||
"GroupQueryAttention",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=prefix + "GroupQueryAttention",
|
||||
domain="com.microsoft",
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=num_heads,
|
||||
)
|
||||
return [node]
|
||||
|
||||
def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
|
||||
assert len(inputs) == 5
|
||||
assert len(outputs) == 2
|
||||
node = helper.make_node(
|
||||
"Attention",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=prefix + "Attention",
|
||||
domain="com.microsoft",
|
||||
num_heads=num_heads,
|
||||
unidirectional=1,
|
||||
do_rotary=1,
|
||||
rotary_embedding_dim=32,
|
||||
)
|
||||
return [node]
|
||||
|
||||
|
||||
class Phi2PreProcessor(DynamoOnnxHelper):
|
||||
def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
|
||||
super().__init__(model)
|
||||
self.num_hidden_layers = 32
|
||||
self.num_attention_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.phi2_edge_dict = self.get_phi2_edge_dict()
|
||||
self.func_name = "modeling_phi_PhiModel_model_1"
|
||||
|
||||
def get_phi2_edge_dict(self) -> dict:
|
||||
edge_dict = {}
|
||||
edge_dict["lm_head_1"] = "logits"
|
||||
edge_dict["l_input_ids_"] = "input_ids"
|
||||
edge_dict["key_states"] = "past_key_0"
|
||||
edge_dict["value_states"] = "past_value_0"
|
||||
for i in range(self.num_hidden_layers):
|
||||
edge_dict[f"key_states_{i}"] = f"past_key_{i}"
|
||||
edge_dict[f"value_states_{i}"] = f"past_value_{i}"
|
||||
edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
|
||||
edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
|
||||
return edge_dict
|
||||
|
||||
def simplify_phi2_op_type(self):
|
||||
phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
|
||||
for node in self.model.graph.node:
|
||||
index = node.op_type.find(phi2_transformer_layer_name)
|
||||
if index != -1:
|
||||
node.op_type = node.op_type[index:]
|
||||
|
||||
def process_graph_io(self, attn_op_type: AttentionOpType):
|
||||
self.use_attn = attn_op_type == AttentionOpType.Attention
|
||||
graph = self.model.graph
|
||||
new_inputs = []
|
||||
for vi in graph.input:
|
||||
if "input_ids" in vi.name:
|
||||
vi_iid = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
elem_type=TensorProto.INT32,
|
||||
shape=["batch_size", "seq_len"],
|
||||
)
|
||||
vi_pid = helper.make_tensor_value_info(
|
||||
"step",
|
||||
elem_type=TensorProto.INT64,
|
||||
shape=[1],
|
||||
)
|
||||
vi_mask = helper.make_tensor_value_info(
|
||||
"attention_mask",
|
||||
elem_type=TensorProto.INT32,
|
||||
shape=["batch_size", "seq_len"],
|
||||
)
|
||||
new_inputs.extend([vi_iid, vi_pid, vi_mask])
|
||||
if not self.use_attn:
|
||||
if "past_key" in vi.name or "past_value" in vi.name:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
"batch_size",
|
||||
self.num_attention_heads,
|
||||
"past_seq_len",
|
||||
self.hidden_size // self.num_attention_heads,
|
||||
],
|
||||
)
|
||||
new_inputs.extend([vi_cache])
|
||||
else:
|
||||
if "past_key" in vi.name:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name.replace("past_key", "past"),
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
2,
|
||||
"batch_size",
|
||||
self.num_attention_heads,
|
||||
"past_seq_len",
|
||||
self.hidden_size // self.num_attention_heads,
|
||||
],
|
||||
)
|
||||
new_inputs.extend([vi_cache])
|
||||
|
||||
graph.ClearField("input")
|
||||
graph.input.extend(new_inputs)
|
||||
|
||||
new_outputs = []
|
||||
for i, vi in enumerate(graph.output):
|
||||
if i == 0:
|
||||
new_outputs.extend([vi])
|
||||
else:
|
||||
if not self.use_attn:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name,
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
"batch_size",
|
||||
self.num_attention_heads,
|
||||
"total_seq_len",
|
||||
self.hidden_size // self.num_attention_heads,
|
||||
],
|
||||
)
|
||||
new_outputs.extend([vi_cache])
|
||||
else:
|
||||
if "present_key" in vi.name:
|
||||
vi_cache = helper.make_tensor_value_info(
|
||||
vi.name.replace("present_key", "present"),
|
||||
elem_type=vi.type.tensor_type.elem_type,
|
||||
shape=[
|
||||
2,
|
||||
"batch_size",
|
||||
self.num_attention_heads,
|
||||
"total_seq_len",
|
||||
self.hidden_size // self.num_attention_heads,
|
||||
],
|
||||
)
|
||||
new_outputs.extend([vi_cache])
|
||||
|
||||
graph.ClearField("output")
|
||||
graph.output.extend(new_outputs)
|
||||
|
||||
def preprocess_onnx(self, attn_op_type: AttentionOpType):
|
||||
function_name = None
|
||||
for func in self.model.functions:
|
||||
if func.name.endswith(self.func_name):
|
||||
function_name = func.name
|
||||
break
|
||||
assert function_name is not None
|
||||
self.unroll_function(function_name)
|
||||
self.update_edges(self.phi2_edge_dict)
|
||||
self.simplify_phi2_op_type()
|
||||
self.remove_dropout_layer()
|
||||
self.process_graph_io(attn_op_type)
|
||||
|
||||
|
||||
class FissionTransformerEmbeddingPhi(Fission):
|
||||
def __init__(
|
||||
self,
|
||||
model: OnnxModel,
|
||||
):
|
||||
super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"])
|
||||
|
||||
def fuse(self, node, input_name_to_nodes, output_name_to_node):
|
||||
logger.info("Optimizing %s...", node.name)
|
||||
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
|
||||
input = node.input[0]
|
||||
output = node.output[0]
|
||||
|
||||
embedding = self.get_io_by_name(node, "embed_tokens.weight")
|
||||
|
||||
layer_known_edges_names = [input, output, embedding]
|
||||
|
||||
subgraph_nodes = [
|
||||
helper.make_node(
|
||||
"Gather",
|
||||
inputs=[embedding, input],
|
||||
outputs=[output],
|
||||
name="Embedding_Gather",
|
||||
),
|
||||
]
|
||||
|
||||
self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names)
|
||||
self.nodes_to_remove.append(node)
|
||||
self.prune_graph = True
|
||||
|
||||
|
||||
class FissionTransformerLayerNormPhi(Fission):
|
||||
def __init__(
|
||||
self,
|
||||
model: OnnxModel,
|
||||
):
|
||||
super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"])
|
||||
|
||||
def fuse(self, node, input_name_to_nodes, output_name_to_node):
|
||||
logger.info("Optimizing %s...", node.name)
|
||||
|
||||
assert len(node.input) == 3
|
||||
assert len(node.output) == 1
|
||||
|
||||
input = node.input[0]
|
||||
output = node.output[0]
|
||||
|
||||
ln_weight = self.get_io_by_name(node, "final_layernorm.weight")
|
||||
ln_bias = self.get_io_by_name(node, "final_layernorm.bias")
|
||||
|
||||
layer_known_edges_names = [input, output, ln_weight, ln_bias]
|
||||
|
||||
subgraph_nodes = []
|
||||
subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final"))
|
||||
|
||||
self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
|
||||
|
||||
self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
|
||||
self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"])
|
||||
|
||||
self.nodes_to_remove.append(node)
|
||||
self.prune_graph = True
|
||||
|
||||
|
||||
class FissionTransformerCausalLMHeadPhi(Fission):
    def __init__(
        self,
        model: OnnxModel,
    ):
        super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        logger.info("Optimizing %s...", node.name)

        assert len(node.input) == 5
        assert len(node.output) == 1

        input = node.input[2]
        output = node.output[0]

        fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
        fc_bias = self.get_io_by_name(node, "lm_head.bias")

        layer_known_edges_names = [input, output, fc_weight, fc_bias]

        subgraph_nodes = []
        subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_"))

        self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)

        self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
        self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200])

        self.nodes_to_remove.append(node)
        self.prune_graph = True


class FissionTransformerBlockPhi(Fission):
    def __init__(
        self,
        model: OnnxModel,
        num_heads: int,
    ):
        self.num_heads = num_heads
        max_num_layers = 32
        self.func_to_layer_id = {}
        nodes_to_find = []
        for layer in range(max_num_layers):
            func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1"
            nodes_to_find.append(func_name)
            self.func_to_layer_id[func_name] = layer

        super().__init__(model, nodes_to_find)

    def get_layer_id(self, node):
        return self.func_to_layer_id[node.op_type]

    def get_gqa_aux_nodes(self):
        # Auxiliary subgraph shared by the GroupQueryAttention nodes: derives the
        # seqlens_k and total_sequence_length inputs from the graph's attention_mask.
        gqa_aux_nodes = [
            helper.make_node(
                "Cast",
                inputs=["attention_mask"],
                outputs=["mask_int64"],
                name="Cast_gqa_aux_0",
                to=TensorProto.INT64,
            ),
            helper.make_node(
                "ReduceSum",
                inputs=["mask_int64", "one"],
                outputs=["mask_row_sums"],
                name="ReduceSum_gqa_aux",
            ),
            helper.make_node(
                "Sub",
                inputs=["mask_row_sums", "one"],
                outputs=["seqlens_k_int64"],
                name="Sub_gqa_aux",
            ),
            helper.make_node(
                "Cast",
                inputs=["seqlens_k_int64"],
                outputs=["seqlens_k"],
                name="Cast_gqa_aux_1",
                to=TensorProto.INT32,
            ),
            helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"),
            helper.make_node(
                "Gather",
                inputs=["mask_shape", "one"],
                outputs=["total_seq_len_int64"],
                name="Gather_gqa_aux_0",
                axis=0,
            ),
            helper.make_node(
                "Cast",
                inputs=["total_seq_len_int64"],
                outputs=["total_sequence_length"],
                name="Cast_gqa_aux_2",
                to=TensorProto.INT32,
            ),
        ]
        return gqa_aux_nodes

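For reference, the auxiliary subgraph above is equivalent to the following NumPy computation (an illustrative sketch, not part of the diff; the toy attention_mask values are made up):

import numpy as np

# attention_mask: (batch_size, total_seq_len), 1 for real tokens, 0 for padding.
attention_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])

mask_int64 = attention_mask.astype(np.int64)                    # Cast_gqa_aux_0
mask_row_sums = mask_int64.sum(axis=1, keepdims=True)           # ReduceSum_gqa_aux (axes = "one")
seqlens_k = (mask_row_sums - 1).astype(np.int32)                # Sub_gqa_aux + Cast_gqa_aux_1
total_sequence_length = np.int32(mask_int64.shape[1])           # Shape + Gather(axis=0, index=1) + Cast_gqa_aux_2
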
    def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name):
        # Pack the separate Q/K/V projection weights and biases into single packed-QKV
        # initializers used by the fused Attention op.
        q_weight = self.model.get_initializer(q_w)
        k_weight = self.model.get_initializer(k_w)
        v_weight = self.model.get_initializer(v_w)
        qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0))
        kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0))
        vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0))
        qkv_weight = np.stack((qw, kw, vw), axis=1)

        q_bias = self.model.get_initializer(q_b)
        k_bias = self.model.get_initializer(k_b)
        v_bias = self.model.get_initializer(v_b)
        qb = NumpyHelper.to_array(q_bias)
        kb = NumpyHelper.to_array(k_bias)
        vb = NumpyHelper.to_array(v_bias)
        qkv_bias = np.stack((qb, kb, vb), axis=0)

        hidden_size = qkv_weight.shape[0]

        weight = helper.make_tensor(
            weight_name,
            data_type=TensorProto.FLOAT,
            dims=[hidden_size, hidden_size * 3],
            vals=qkv_weight.flatten().tobytes(),
            raw=True,
        )
        self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            bias_name,
            data_type=TensorProto.FLOAT,
            dims=[hidden_size * 3],
            vals=qkv_bias.flatten().tobytes(),
            raw=True,
        )
        self.model.add_initializer(bias, self.this_graph_name)

        self.add_fp32_value_info(weight.name)
        self.add_fp32_value_info(bias.name)

        return weight_name, bias_name

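To make the packed layout concrete, here is a toy NumPy illustration of what pack_qkv_gemm builds (a sketch only; the tiny hidden size and random weights are made up, and the (out_features, in_features) storage order is the usual PyTorch Linear convention):

import numpy as np

hidden = 4  # toy size for illustration
q_w = np.random.rand(hidden, hidden).astype(np.float32)  # assumed stored as (out_features, in_features)
k_w = np.random.rand(hidden, hidden).astype(np.float32)
v_w = np.random.rand(hidden, hidden).astype(np.float32)

# Transpose each weight to (in_features, out_features) and interleave Q/K/V per input row,
# mirroring np.stack(..., axis=1) above; the flattened result fills an initializer of
# dims [hidden, hidden * 3].
qkv_weight = np.stack((q_w.T, k_w.T, v_w.T), axis=1)
print(qkv_weight.shape)  # (4, 3, 4)
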
    def fuse(
        self,
        node,
        input_name_to_nodes,
        output_name_to_node,
    ):
        logger.info("Optimizing %s...", node.name)

        logger.info(f"AttentionOpType: {self.attn_op_type}")

        layer_id = self.get_layer_id(node)

        i_hidden_states = node.input[0]
        i_key_cache = self.get_io_by_name(node, "past_key")
        i_value_cache = self.get_io_by_name(node, "past_value")

        o_hidden_states = node.output[3]
        o_key_cache = self.get_io_by_name(node, "present_key")
        o_value_cache = self.get_io_by_name(node, "present_value")

        ln_weight = self.get_io_by_name(node, "input_layernorm.weight")
        ln_bias = self.get_io_by_name(node, "input_layernorm.bias")

        attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
            None,
            None,
            None,
            None,
            None,
            None,
        )
        attn_qkv_weight, attn_qkv_bias = None, None
        cos_cache, sin_cache = None, None

        if self.attn_op_type != AttentionOpType.Attention:
            attn_q_weight = self.process_initializer(
                self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
            )
            attn_k_weight = self.process_initializer(
                self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
            )
            attn_v_weight = self.process_initializer(
                self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
            )
            attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias")
            attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias")
            attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias")

            cos_cache = self.process_initializer(
                self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
            )
            sin_cache = self.process_initializer(
                self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
            )
        else:
            attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
                self.get_io_by_name(node, "self_attn.q_proj.weight"),
                self.get_io_by_name(node, "self_attn.k_proj.weight"),
                self.get_io_by_name(node, "self_attn.v_proj.weight"),
                self.get_io_by_name(node, "self_attn.q_proj.bias"),
                self.get_io_by_name(node, "self_attn.k_proj.bias"),
                self.get_io_by_name(node, "self_attn.v_proj.bias"),
                self.get_uname(layer_id, "attn_qkv_weight"),
                self.get_uname(layer_id, "attn_qkv_bias"),
            )

        attn_out_weight = self.process_initializer(
            self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
        )
        attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias")

        mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
        mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
        mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias")
        mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias")

        layer_known_edges_names = []
        layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
        layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache])
        layer_known_edges_names.extend([ln_weight, ln_bias])
        if self.attn_op_type != AttentionOpType.Attention:
            layer_known_edges_names.extend(
                [
                    attn_q_weight,
                    attn_q_bias,
                    attn_k_weight,
                    attn_k_bias,
                    attn_v_weight,
                    attn_v_bias,
                    cos_cache,
                    sin_cache,
                ]
            )
        else:
            layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias])
        layer_known_edges_names.extend(
            [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
        )
        layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"])

        subgraph_nodes = []
        subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
        subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_"))
        subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_"))
        subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"]))
        subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_"))
        subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1"))
        subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2"))
        if self.attn_op_type != AttentionOpType.Attention:
            subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
            subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
            subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
            subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_"))
            subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_"))
            if self.attn_op_type == AttentionOpType.MultiHeadAttention:
                subgraph_nodes.extend(
                    self.mha(
                        ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache],
                        ["attn_out", o_key_cache, o_value_cache],
                    )
                )
            elif self.attn_op_type == AttentionOpType.GroupQueryAttention:
                subgraph_nodes.extend(
                    self.gqa(
                        [
                            "query_rot",
                            "key_rot",
                            "value",
                            i_key_cache,
                            i_value_cache,
                            "seqlens_k",
                            "total_sequence_length",
                        ],
                        ["attn_out", o_key_cache, o_value_cache],
                    )
                )
                if layer_id == 0:
                    # The GQA auxiliary nodes and the "one" constant are shared across layers,
                    # so add them to the graph only once (when fusing layer 0).
                    gqa_aux_nodes = self.get_gqa_aux_nodes()
                    for new_node in gqa_aux_nodes:
                        self.nodes_to_add.append(new_node)
                        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
                    self.model.add_initializer(
                        numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
                    )
        else:
            past_name = f"past_{layer_id}"
            present_name = f"present_{layer_id}"
            layer_known_edges_names.extend([past_name, present_name])
            subgraph_nodes.extend(
                self.attention(
                    ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name]
                )
            )

        self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)

        self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
        self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])

        self.nodes_to_remove.append(node)
        self.prune_graph = True


class PhiOnnxModel(OnnxModel):
    def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
        super().__init__(model)
        self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size)
        self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads)
        self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self)
        self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self)
        self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self)

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        assert options is not None
        attn_op_type = options.attention_op_type

        self.fission_transformer_block.set_attention_op_type(attn_op_type)

        self.phi2_preprocessor.preprocess_onnx(attn_op_type)

        self.fission_transformer_block.apply()
        self.fission_transformer_layernorm.apply()
        self.fission_causal_lm_head.apply()
        self.fission_transformer_embedding.apply()

        super().prune_graph()

        # SLN ctor is placed here intentionally to delay the symbolic shape inference
        self.fuse_sln = FusionSkipLayerNormalization(self)
        self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self)
        self.fuse_sln.apply()
        self.fuse_bias_sln.apply()

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            "Attention",
            "MultiHeadAttention",
            "GroupQueryAttention",
            "Gelu",
            "BiasGelu",
            "FastGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators: {op_count}")
        return op_count

    def is_fully_optimized(self, fused_op_count=None):
        """
        Returns True when the model is fully optimized.
        """
        if fused_op_count is None:
            fused_op_count = self.get_fused_operator_statistics()

        def op_count(op_name: str):
            return fused_op_count.get(op_name) or 0

        attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("GroupQueryAttention")
        gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
        layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")

        is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu (or FastGelu) not fused")

        if attention == 0:
            logger.warning("Attention (or MultiHeadAttention) not fused")

        return is_perfect

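A minimal usage sketch of the class above (not part of the diff). The attribute and enum names follow what PhiOnnxModel.optimize reads; the file paths, the FusionOptions import location, and the 32-head / 2560-hidden configuration (phi-2's published shape) are assumptions here — the phi2 README's conversion script is the supported entry point.

import onnx
from fusion_options import AttentionOpType, FusionOptions  # assumed import path within the transformers tools

model = onnx.load("phi2_dynamo_exported.onnx")  # hypothetical input path
options = FusionOptions("phi")
options.attention_op_type = AttentionOpType.GroupQueryAttention  # read by PhiOnnxModel.optimize

phi_model = PhiOnnxModel(model, num_heads=32, hidden_size=2560)
phi_model.optimize(options)
phi_model.get_fused_operator_statistics()
phi_model.save_model_to_file("phi2_optimized.onnx")
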
@ -34,6 +34,7 @@ from onnx_model_bert_tf import BertOnnxModelTF
from onnx_model_clip import ClipOnnxModel
from onnx_model_conformer import ConformerOnnxModel
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_phi import PhiOnnxModel
from onnx_model_t5 import T5OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
from onnx_model_unet import UnetOnnxModel

@ -58,6 +59,7 @@ MODEL_TYPES = {
    "vae": (VaeOnnxModel, "pytorch", 1),  # UAE in Stable Diffusion
    "vit": (BertOnnxModel, "pytorch", 1),
    "conformer": (ConformerOnnxModel, "pytorch", 1),
    "phi": (PhiOnnxModel, "pytorch", 0),
}

setup.py

@ -419,6 +419,7 @@ packages = [
    "onnxruntime.transformers.models.gpt2",
    "onnxruntime.transformers.models.llama",
    "onnxruntime.transformers.models.longformer",
    "onnxruntime.transformers.models.phi2",
    "onnxruntime.transformers.models.t5",
    "onnxruntime.transformers.models.stable_diffusion",
    "onnxruntime.transformers.models.whisper",