phi2 conversion/optimization script (#19338)

### Description
This PR adds an ONNX conversion script for the Dynamo-exported phi2 model, an optimization script, and an inference example script.

A README file is added as documentation:
https://github.com/microsoft/onnxruntime/tree/wangye/phi2_doc/onnxruntime/python/tools/transformers/models/phi2#readme



---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Ye Wang 2024-02-05 18:15:16 +00:00, committed by GitHub
Parent e6d3518db9
Commit aaf32fb1b1
No key matching this signature was found
GPG key ID: B5690EEEBB952194
14 changed files: 1801 additions and 3 deletions


@ -473,6 +473,9 @@ file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py"
)
file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py"
)
file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py"
)
@ -543,6 +546,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/whisper
@ -646,6 +650,9 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_longformer_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_phi2_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_stable_diffusion_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion/


@ -205,6 +205,7 @@ class SymbolicShapeInference:
"GemmFastGelu": self._infer_GemmFastGelu,
"GemmFloat8": self._infer_GemmFloat8,
"GroupNorm": self._infer_GroupNorm,
"GroupQueryAttention": self._infer_GroupQueryAttention,
"SkipGroupNorm": self._infer_SkipGroupNorm,
"LayerNormalization": self._infer_LayerNormalization,
"LongformerAttention": self._infer_LongformerAttention,
@ -471,6 +472,7 @@ class SymbolicShapeInference:
"PythonOp",
"MultiHeadAttention",
"GroupNorm",
"GroupQueryAttention",
"SkipGroupNorm",
"BiasSplitGelu",
"BiasAdd",
@ -2409,6 +2411,32 @@ class SymbolicShapeInference:
def _infer_GroupNorm(self, node): # noqa: N802
self._propagate_shape_and_type(node)
def _infer_GroupQueryAttention(self, node): # noqa: N802
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
past_shape = self._try_get_shape(node, 3)
if past_shape is not None:
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
vi = self.known_vi_[node.output[2]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
if node.input[1] != "" and node.input[2] != "":
self._propagate_shape_and_type(node, 0, 0)
else:
# combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size)
assert node.input[1] == "" and node.input[2] == ""
num_heads = get_attribute(node, "num_heads")
kv_num_heads = get_attribute(node, "kv_num_heads")
query_shape = self._get_shape(node, 0)
if query_shape is not None:
hidden_size = query_shape[2]
if isinstance(hidden_size, int):
head_size = int(hidden_size / (num_heads + 2 * kv_num_heads))
query_shape[2] = num_heads * head_size
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape))
def _infer_SkipGroupNorm(self, node): # noqa: N802
self._propagate_shape_and_type(node, 0, 0)
if len(node.output) > 1:
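For intuition, here is a small sketch of the packed-QKV arithmetic that `_infer_GroupQueryAttention` applies above when the separate key and value inputs are empty. The numbers (32 heads, 32 KV heads, head size 80) match the phi-2 values used by the inference example later in this change and are only illustrative.

```python
# Minimal sketch of the packed-QKV size arithmetic (illustrative phi-2 numbers).
num_heads = 32        # attention heads
kv_num_heads = 32     # KV heads
packed_hidden = 7680  # query_shape[2] = num_heads*head_size + 2*kv_num_heads*head_size

head_size = packed_hidden // (num_heads + 2 * kv_num_heads)  # 7680 // 96 = 80
output_hidden = num_heads * head_size                        # 2560, last dim of output[0]
print(head_size, output_hidden)
```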


@ -0,0 +1,92 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import onnx
class DynamoOnnxHelper:
"""
Helper class for processing ONNX models exported by torch Dynamo.
"""
def __init__(self, model: onnx.ModelProto):
self.model = model
def update_edges(self, edge_mapping: dict) -> None:
"""
Updates the edges in the model according to the given mapping.
"""
for node in self.model.graph.node:
for i in range(len(node.input)):
if node.input[i] in edge_mapping:
node.input[i] = edge_mapping[node.input[i]]
for i in range(len(node.output)):
if node.output[i] in edge_mapping:
node.output[i] = edge_mapping[node.output[i]]
for graph_input in self.model.graph.input:
if graph_input.name in edge_mapping:
graph_input.name = edge_mapping[graph_input.name]
for graph_output in self.model.graph.output:
if graph_output.name in edge_mapping:
graph_output.name = edge_mapping[graph_output.name]
def unroll_function(self, func_name: str) -> None:
"""
Unrolls the function with the given name in the model.
"""
logging.info(f"Unrolling function {func_name}...")
nodes_to_remove = []
nodes_to_add = []
edges_to_remove = []
edges_to_add = []
for node in self.model.graph.node:
if node.op_type == func_name:
nodes_to_remove.append(node)
edges_to_remove.extend(list(node.input) + list(node.output))
func_to_remove = None
for f in self.model.functions:
if f.name == func_name:
nodes_to_add.extend(list(f.node))
edges_to_add.extend(list(f.input) + list(f.output))
func_to_remove = f
assert len(edges_to_remove) == len(edges_to_add)
for node in nodes_to_remove:
self.model.graph.node.remove(node)
for node in nodes_to_add:
self.model.graph.node.append(node)
if func_to_remove is not None:
self.model.functions.remove(func_to_remove)
edge_mapping = {}
for i in range(len(edges_to_remove)):
k = edges_to_remove[i]
v = edges_to_add[i]
if k != v:
edge_mapping[k] = v
return self.update_edges(edge_mapping)
def remove_dropout_layer(self) -> None:
"""
Removes the dropout layer in the model.
"""
logging.info("Removing dropout layer...")
edge_mapping = {}
nodes_to_remove = []
for node in self.model.graph.node:
if node.op_type.find("Dropout") != -1:
assert len(node.input) == 1
assert len(node.output) == 1
edge_mapping[node.output[0]] = node.input[0]
nodes_to_remove.append(node)
for node in nodes_to_remove:
self.model.graph.node.remove(node)
self.update_edges(edge_mapping)
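A minimal usage sketch of this helper, assuming a Dynamo-exported model on disk and the transformers tools directory on `sys.path`; the file paths are illustrative, and the exact exported function name depends on the export (the phi2 preprocessor later in this change looks for a function whose name ends with `modeling_phi_PhiModel_model_1`).

```python
import onnx

from dynamo_onnx_helper import DynamoOnnxHelper

# Load a Dynamo-exported model (path is illustrative).
model = onnx.load("phi2_original.onnx")
helper = DynamoOnnxHelper(model)

# Inline the exported local function, then remove Dropout nodes.
helper.unroll_function("modeling_phi_PhiModel_model_1")  # assumed function name
helper.remove_dropout_layer()

onnx.save(helper.model, "phi2_preprocessed.onnx")
```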


@ -174,6 +174,7 @@ def convert_float_to_float16(
node_block_list=None,
force_fp16_initializers=False,
force_fp16_inputs=None,
use_bfloat16_as_blocked_nodes_dtype=False,
):
"""Convert tensor float type in the input ONNX model to tensor float16.
@ -436,6 +437,7 @@ def convert_float_to_float16(
node.input[i] = output_name
break
accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT
# process the nodes in block list that doesn't support tensor(float16)
for node in node_list:
# if input's name is in the value_info_list meaning input is tensor(float16) type,
@ -450,10 +452,10 @@ def convert_float_to_float16(
new_value_info.CopyFrom(value_info)
output_name = node.name + "_input_cast_" + str(i)
new_value_info.name = output_name
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
new_value_info.type.tensor_type.elem_type = accuracy_type
# add Cast node (from tensor(float16) to tensor(float) before current node
node_name = node.name + "_input_cast" + str(i)
new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)]
model.graph.node.extend(new_node)
# change current node's input name
node.input[i] = output_name
@ -469,7 +471,7 @@ def convert_float_to_float16(
new_value_info.CopyFrom(value_info)
input_name = node.name + "_output_cast_" + str(i)
new_value_info.name = input_name
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
new_value_info.type.tensor_type.elem_type = accuracy_type
# add Cast node (from tensor(float) to tensor(float16) after current node
node_name = node.name + "_output_cast" + str(i)
new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)]
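For reference, a sketch of how the phi2 optimization script (convert_to_onnx.py, later in this change) exercises the new flag through the optimizer wrapper; the block list holds the last attention layers of phi2, which are kept in a higher-precision dtype.

```python
from onnx_model_phi import PhiOnnxModel  # the model type returned by optimize_model(..., model_type="phi")


def convert_to_mixed_precision(optimizer: PhiOnnxModel) -> None:
    # Everything becomes fp16 except the blocked nodes; with use_bfloat16_as_blocked_nodes_dtype=True
    # the Cast nodes inserted around them target bfloat16 instead of float32.
    optimizer.convert_float_to_float16(
        keep_io_types=False,
        node_block_list=["GroupQueryAttention_29", "GroupQueryAttention_30", "GroupQueryAttention_31"],
        use_symbolic_shape_infer=True,
        use_bfloat16_as_blocked_nodes_dtype=True,
    )
```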


@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from argparse import ArgumentParser
from enum import Enum
class AttentionMaskFormat:
@ -19,6 +20,15 @@ class AttentionMaskFormat:
NoMask = 3
class AttentionOpType(Enum):
Attention = "Attention"
MultiHeadAttention = "MultiHeadAttention"
GroupQueryAttention = "GroupQueryAttention"
def __str__(self):
return self.value
class FusionOptions:
"""Options of fusion in graph optimization"""
@ -57,6 +67,8 @@ class FusionOptions:
elif model_type == "vit":
self.attention_mask_format = AttentionMaskFormat.NoMask
self.attention_op_type = None
# options for stable diffusion
if model_type in ["unet", "vae", "clip"]:
self.enable_nhwc_conv = True
@ -76,6 +88,9 @@ class FusionOptions:
def disable_attention_mask(self):
self.attention_mask_format = AttentionMaskFormat.NoMask
def set_attention_op_type(self, attn_op_type: AttentionOpType):
self.attention_op_type = attn_op_type
@staticmethod
def parse(args):
options = FusionOptions(args.model_type)


@ -0,0 +1,119 @@
# Phi2 Optimizations
## Prerequisites
A Linux machine is required for the [TorchDynamo-based ONNX Exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter).\
Install onnx, onnxscript, and transformers by running
```bash
pip install -r requirements.txt
```
Exporting the ONNX model requires PyTorch 2.2.0 or higher. The [official website](https://pytorch.org/) offers packages compatible with CUDA 11.8 and 12.1. Please select the version appropriate for your environment.
\
\
**There are two options to run the conversion script:**\
_From source:_
```bash
pip install onnxruntime-gpu==1.17.0 # or onnxruntime==1.17.0 if using cpu
git clone git@github.com:microsoft/onnxruntime.git
cd onnxruntime/onnxruntime/python/tools/transformers
python -m models.phi2.convert_to_onnx -h
```
_From wheel:_ \
Install [ORT nightly package](https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages)
```bash
python -m onnxruntime.transformers.models.phi2.convert_to_onnx -h
```
## Export optimized phi2 onnx model for different scenarios
**Export FP32 ONNX model for Nvidia GPUs** \
_From source:_
```
python -m models.phi2.convert_to_onnx --fp32_gpu
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_gpu
```
\
**Export FP16 ONNX model for Nvidia GPUs** \
_From source:_
```
python -m models.phi2.convert_to_onnx --fp16_gpu
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu
```
\
**Export INT4 ONNX model for Nvidia GPUs** \
_From source:_
```
python -m models.phi2.convert_to_onnx --int4_gpu
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu
```
\
**Export FP16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \
_From source:_
```
python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x
```
\
**Export INT4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \
_From source:_
```
python -m models.phi2.convert_to_onnx --int4_gpu_sm8x
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu_sm8x
```
\
**Export FP32 ONNX model for CPU** \
_From source:_
```
python -m models.phi2.convert_to_onnx --fp32_cpu
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu
```
\
**Export INT4 ONNX model for CPU** \
_From source:_
```
python -m models.phi2.convert_to_onnx --int4_cpu
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_cpu
```
\
**Export all at once** \
_From source:_
```
python -m models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x
```
## Run example with ORT
**For example: export FP16 and INT4 ONNX models for Nvidia GPUs with CUDA architecture SM=80~89 and run the examples.** \
_From source:_
```
python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example
```
_From wheel:_
```
python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example
```
The inference example currently supports all of the exported models that run on CUDA; CPU inference is not yet implemented.
## Limitations
- The TorchDynamo-based ONNX Exporter only supports Linux.
- The scripts may not run as expected on machines with limited memory. For example, the Dynamo export may use ~11.6 GB of memory, and each optimization may use ~4.5 GB.


@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)


@ -0,0 +1,458 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations
import argparse
import logging
import os
from pathlib import Path
import onnx
import torch
from benchmark_helper import Precision
from fusion_options import AttentionOpType
from transformers import AutoConfig, AutoModelForCausalLM
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
class ConvertPhi2ToONNX:
def __init__(
self,
device: torch.device,
model_class: str = "microsoft/phi-2",
cache_dir: str = "./cache",
):
self.model_class = model_class
self.device = device
self.cache_dir = cache_dir
self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir)
self.phi_model = None
self.batch_size = 2
self.sequence_length = 8
self.attn_op_type = None
self.precision = None
self.block_size = 16
self.accuracy_level = None
def set_quantization_params(self, block_size: int, accuracy_level: int | None):
self.block_size = block_size
self.accuracy_level = accuracy_level
def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision):
self.attn_op_type = attn_op_type
self.precision = precision
def erase_onnx_model(self, onnx_path: str) -> None:
assert onnx_path.endswith(".onnx")
if not os.path.exists(onnx_path):
return
model = onnx.load_model(onnx_path, load_external_data=False)
onnx_data_path = None
for initializer in model.graph.initializer:
if initializer.data_location == 1 and initializer.external_data[0].key == "location":
onnx_data_path = "./" + initializer.external_data[0].value
break
logging.info(f"Erasing {onnx_path}...")
os.remove(onnx_path)
if onnx_data_path is not None:
onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path)
logging.info(f"Erasing {onnx_data_path}...")
os.remove(onnx_data_path)
def get_phi2_torch_model(self):
logging.info("Loading phi2 torch model...")
if self.phi_model is not None:
return
self.phi_model = AutoModelForCausalLM.from_pretrained(
self.model_class, trust_remote_code=True, cache_dir=self.cache_dir
)
self.phi_model.eval()
self.phi_model.to(self.device)
def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int):
input_ids = torch.randint(
low=0,
high=self.phi_config.vocab_size,
size=(batch_size, sequence_length),
dtype=torch.int64,
device=self.device,
)
self.get_phi2_torch_model()
torch_inputs = self.phi_model.prepare_inputs_for_generation(
input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"]
)
return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"]
def dynamo_export(self, onnx_path: str):
input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length)
self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
from torch._dynamo import config
config.capture_scalar_outputs = True
logging.info("Exporting Phi2 torch model to ONNX...")
torch.onnx.dynamo_export(
self.phi_model,
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
).save(onnx_path)
onnx.checker.check_model(onnx_path)
onnx.shape_inference.infer_shapes_path(onnx_path)
def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
from fusion_options import FusionOptions
from optimizer import optimize_model
optimization_options = FusionOptions("phi")
optimization_options.set_attention_op_type(self.attn_op_type)
optimizer = optimize_model(
onnx_path,
model_type="phi",
num_heads=self.phi_config.num_attention_heads,
hidden_size=self.phi_config.hidden_size,
opt_level=0,
optimization_options=optimization_options,
only_onnxruntime=False,
)
fused_op_count = optimizer.get_fused_operator_statistics()
if optimizer.is_fully_optimized(fused_op_count):
logging.info("Model is fully optimized.")
else:
logging.info("Model is not fully optimized.")
if self.precision == Precision.FLOAT32:
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
return
if (
self.precision == Precision.FLOAT16 or self.precision == Precision.INT4
) and self.attn_op_type != AttentionOpType.MultiHeadAttention:
# We keep the last three layers of Attention as float32 or bfloat16 to avoid overflow.
node_block_list = [
"GroupQueryAttention_29",
"GroupQueryAttention_30",
"GroupQueryAttention_31",
"Attention_29",
"Attention_30",
"Attention_31",
]
logging.info("Converting onnx model to float16/bfloat16...")
optimizer.convert_float_to_float16(
keep_io_types=False,
node_block_list=node_block_list,
use_symbolic_shape_infer=True,
use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention,
)
logging.info("Converting onnx model to float16/bfloat16 done.")
if self.precision == Precision.FLOAT16:
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
return
else:
assert self.precision == Precision.INT4
quant = MatMul4BitsQuantizer(
model=optimizer.model,
block_size=self.block_size,
is_symmetric=True,
accuracy_level=self.accuracy_level,
)
quant.process()
quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--fp32_cpu",
required=False,
action="store_true",
help="Generate fp32 ONNX model for CPU",
)
parser.add_argument(
"--int4_cpu",
required=False,
action="store_true",
help="Generate int4 ONNX model for CPU",
)
parser.add_argument(
"--fp32_gpu",
required=False,
action="store_true",
help="Generate fp32 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--fp16_gpu",
required=False,
action="store_true",
help="Generate fp16 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--int4_gpu",
required=False,
action="store_true",
help="Generate int4 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--fp16_gpu_sm8x",
required=False,
action="store_true",
help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--int4_gpu_sm8x",
required=False,
action="store_true",
help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing ONNX models",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./cache",
help="The cache directory for the pytorch model",
)
parser.add_argument(
"--device_id",
required=False,
type=int,
default=0,
help="The device id for the pytorch model",
)
parser.add_argument(
"--run_example",
required=False,
action="store_true",
help="Run ORT inference example",
)
parser.add_argument(
"--skip_export",
required=False,
action="store_true",
help="Skip exporting ONNX model",
)
parser.add_argument(
"--output_dir",
type=str,
help="The output directory for the ONNX models",
default="phi2_onnx_models",
)
parser.add_argument(
"--block_size",
required=False,
default=16,
type=int,
help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
)
parser.add_argument(
"--int4_accuracy_level",
required=False,
type=int,
help="Accuracy level of the 4-bit quantized MatMul computation. "
"Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
"(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
)
args = parser.parse_args()
return args
def main():
args = parse_arguments()
device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu")
converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir)
converter.set_quantization_params(args.block_size, args.int4_accuracy_level)
output_dir = args.output_dir
if not os.path.exists(output_dir):
os.makedirs(output_dir)
original_onnx_path = os.path.join(output_dir, "phi2_original.onnx")
if not args.skip_export:
if not os.path.exists(original_onnx_path) or args.overwrite:
converter.dynamo_export(original_onnx_path)
model_type_to_args = {
"fp32_cpu": (
AttentionOpType.MultiHeadAttention,
Precision.FLOAT32,
os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"),
),
"int4_cpu": (
AttentionOpType.MultiHeadAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"),
),
"fp32_gpu": (
AttentionOpType.Attention,
Precision.FLOAT32,
os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"),
),
"fp16_gpu": (
AttentionOpType.Attention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"),
),
"int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")),
"fp16_gpu_sm8x": (
AttentionOpType.GroupQueryAttention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"),
),
"int4_gpu_sm8x": (
AttentionOpType.GroupQueryAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
),
}
if not args.skip_export:
from multiprocessing import Process
def run_optimize_phi2_onnx(
converter: ConvertPhi2ToONNX,
original_onnx_path: str,
attention_type: AttentionOpType,
precision: Precision,
optimized_onnx_path: str,
):
converter.init_attn_type_and_precision(attention_type, precision)
converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path)
processes = []
if args.fp32_cpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"])
)
)
if args.int4_cpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"])
)
)
if args.fp32_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"])
)
)
if args.fp16_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"])
)
)
if args.int4_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"])
)
)
if args.fp16_gpu_sm8x:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]),
)
)
if args.int4_gpu_sm8x:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]),
)
)
[p.start() for p in processes]
[p.join() for p in processes]
if args.run_example:
from inference_example import run_phi2
if args.fp16_gpu_sm8x:
logging.info("Running fp16_gpu_sm8x example...")
run_phi2(
onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2],
use_buffer_share=True,
device_id=args.device_id,
use_step=True,
)
if args.int4_gpu_sm8x:
logging.info("Running int4_gpu_sm8x example...")
run_phi2(
onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2],
use_buffer_share=True,
device_id=args.device_id,
use_step=True,
)
if args.fp32_gpu:
logging.info("Running fp32_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["fp32_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
use_fp16=False,
)
if args.fp16_gpu:
logging.info("Running fp16_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["fp16_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
)
if args.int4_gpu:
logging.info("Running int4_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["int4_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
)
if args.fp32_cpu or args.int4_cpu:
raise NotImplementedError("CPU inference example is not implemented yet.")
if __name__ == "__main__":
main()


@ -0,0 +1,215 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import numpy as np
import torch
from transformers import AutoTokenizer
import onnxruntime as ort
pt_to_np = {
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.float32": np.float32,
"torch.float16": np.float16,
}
class ORTGenerator:
def __init__(self, decoder_path):
self.onnx_decoder_path = decoder_path
self.num_heads = 32
self.head_size = 80
self.num_layers = 32
self.max_sequence_length = 2048
def get_initial_inputs_and_outputs(self, encodings_dict):
self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
step = torch.tensor([0], device=self.device, dtype=torch.int64)
inputs = {
"input_ids": input_ids.contiguous(),
"attention_mask": attention_mask.contiguous(),
}
if self.use_step:
inputs["step"] = step.contiguous()
batch_size, sequence_length = input_ids.shape
past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
past_shape = (
(2, batch_size, self.num_heads, past_seq_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, past_seq_length, self.head_size)
)
for i in range(self.num_layers):
past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
inputs.update(
{f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()}
) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()})
logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
outputs = {"logits": logits.contiguous()}
if not self.use_buffer_share:
present_shape = (
(2, batch_size, self.num_heads, sequence_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, sequence_length, self.head_size)
)
for i in range(self.num_layers):
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
outputs.update(
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()})
return inputs, outputs
def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
io_binding = model.io_binding()
device = None
for k, v in inputs.items():
io_binding.bind_input(
name=k,
device_type=v.device.type,
device_id=0 if v.device.type == "cpu" else v.device.index,
element_type=pt_to_np[repr(v.dtype)],
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
device = v.device
for output in model.get_outputs():
name = output.name
if self.use_buffer_share and "present" in name:
v = inputs[name.replace("present", "past")]
io_binding.bind_output(
name=name,
device_type=v.device.type,
device_id=v.device.index,
element_type=(np.float16 if self.use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
else:
v = outputs[name]
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if self.use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
return io_binding
def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False):
sess_options = ort.SessionOptions()
ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider"
self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu")
self.use_fp16 = use_fp16
self.use_buffer_share = use_buffer_share
self.packed_kv = packed_kv
self.use_step = use_step
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
self.tokenizer.pad_token = "[PAD]"
def generate(self, prompt, max_length):
encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
all_token_ids = inputs["input_ids"].clone()
batch_size, sequence_length = all_token_ids.shape
current_length = sequence_length
has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
while current_length < max_length:
io_binding = self.apply_io_binding(self.sess, inputs, outputs)
io_binding.synchronize_inputs()
self.sess.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
# Sample with argmax (greedy search)
next_token_logits = outputs["logits"][:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
# Check if we previously reached EOS token id or if generated token id is EOS token id
has_eos = has_eos | (next_tokens == self.tokenizer.eos_token_id)
# Determine which new tokens to add to list of all token ids
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
# Return early if all batch entries have reached EOS token id
if torch.all(has_eos):
break
# Update inputs for next inference run
current_length += 1
inputs["input_ids"] = tokens_to_add.to(torch.int32)
if self.use_step:
inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to(
torch.int32
)
# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
outputs["logits"].zero_()
if not self.use_buffer_share:
for i in range(self.num_layers):
if not self.packed_kv:
inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
else:
inputs[f"past_{i}"] = outputs[f"present_{i}"]
new_sequence_length = inputs["attention_mask"].shape[1]
present_shape = (
(2, batch_size, self.num_heads, new_sequence_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, new_sequence_length, self.head_size)
)
for i in range(self.num_layers):
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
outputs.update(
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()}
) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()})
texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
return texts
def run_phi2(onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False):
prompt = [
'''```python
def print_prime(n):
"""
Print all primes between 1 and n
"""'''
]
generator = ORTGenerator(onnx_model_path)
generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step)
texts = generator.generate(prompt, max_length=200)
for i in range(len(texts)):
print("Prompt: ", prompt[i])
print("Texts: ", texts[i])


@ -0,0 +1,3 @@
onnx>=1.15.0
transformers>=4.36.2
onnxscript>=0.1.0.dev20240126


@ -82,6 +82,10 @@ class OnnxModel:
output_name_to_node[output_name] = node
return output_name_to_node
def functions(self):
all_functions = [list(self.model.functions)]
return all_functions
def nodes(self):
all_nodes = []
for graph in self.graphs():
@ -733,6 +737,7 @@ class OnnxModel:
"node_block_list",
"force_fp16_initializers",
"force_fp16_inputs",
"use_bfloat16_as_blocked_nodes_dtype",
]
if key in kwargs
}


@ -0,0 +1,839 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import List, Optional
import numpy as np
from dynamo_onnx_helper import DynamoOnnxHelper
from fusion_base import Fusion
from fusion_options import AttentionOpType, FusionOptions
from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
from fusion_utils import NumpyHelper
from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class ProcessGemmWFunc:
def __call__(self, x):
return np.transpose(x, (1, 0))
class ProcessMatMulQFunc:
def __call__(self, x):
return np.transpose(np.split(x, 3, 0)[0], (1, 0))
class ProcessMatMulKFunc:
def __call__(self, x):
return np.transpose(np.split(x, 3, 0)[1], (1, 0))
class ProcessMatMulVFunc:
def __call__(self, x):
return np.transpose(np.split(x, 3, 0)[2], (1, 0))
class ProcessBiasQFunc:
def __call__(self, x):
x = np.split(x, 3, -1)[0]
return x
class ProcessBiasKFunc:
def __call__(self, x):
x = np.split(x, 3, -1)[1]
return x
class ProcessBiasVFunc:
def __call__(self, x):
x = np.split(x, 3, -1)[2]
return x
class ProcessRotCacheFunc:
def __call__(self, x):
# half rotary embedding
assert len(x.shape) == 2
if x.shape[1] == 32:
return x[:, 0:16]
return x
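A tiny numpy sketch of what these functors do, with illustrative shapes for a packed QKV weight `(3*hidden, hidden)` and bias `(3*hidden,)`:

```python
import numpy as np

hidden = 4  # illustrative size
packed_w = np.random.rand(3 * hidden, hidden).astype(np.float32)
packed_b = np.random.rand(3 * hidden).astype(np.float32)

q_w = ProcessMatMulQFunc()(packed_w)   # first third along axis 0, transposed -> (hidden, hidden)
k_b = ProcessBiasKFunc()(packed_b)     # middle third of the bias -> (hidden,)
gemm_w = ProcessGemmWFunc()(packed_w)  # plain transpose -> (hidden, 3 * hidden)
assert q_w.shape == (hidden, hidden) and k_b.shape == (hidden,) and gemm_w.shape == (hidden, 3 * hidden)
```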
# TODO: move to a separate file
class Fission(Fusion):
def __init__(
self,
model: OnnxModel,
nodes_to_find: List[str],
):
super().__init__(model, "DONOTUSE", nodes_to_find)
def set_attention_op_type(self, attn_op_type: AttentionOpType):
self.attn_op_type = attn_op_type
def get_uname(self, layer_id, name):
return name + "_" + str(layer_id)
def get_io_by_name(self, node, name):
for input in node.input:
if input == name or input.endswith(name) or input.startswith(name):
return input
for output in node.output:
if output == name or output.endswith(name) or output.startswith(name):
return output
raise Exception(f"input {name} not found in node {node.name}")
def process_initializer(self, initializer_name, functor, custom_name=None):
i = self.model.get_initializer(initializer_name)
i_np_array = NumpyHelper.to_array(i)
processed_i_np_array = functor(i_np_array)
new_tensor = helper.make_tensor(
initializer_name + "_processed" if custom_name is None else custom_name,
data_type=TensorProto.FLOAT,
dims=processed_i_np_array.shape,
vals=processed_i_np_array.flatten().tobytes(),
raw=True,
)
self.model.add_initializer(new_tensor, self.this_graph_name)
return new_tensor.name
def add_fp32_value_info(self, name):
new_value_info = self.model.graph().value_info.add()
new_value_info.name = name
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
def add_int64_value_info(self, name):
new_value_info = self.model.graph().value_info.add()
new_value_info.name = name
new_value_info.type.tensor_type.elem_type = TensorProto.INT64
def replace_fp32_value_info(self, name, shape):
for value_info in self.model.graph().value_info:
if value_info.name == name:
self.model.graph().value_info.remove(value_info)
break
new_value_info = helper.make_tensor_value_info(
name,
elem_type=TensorProto.FLOAT,
shape=shape,
)
self.model.graph().value_info.extend([new_value_info])
def set_unique_name_and_add_nodes(
self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str]
):
for new_node in subgraph_nodes:
for i, name in enumerate(new_node.input):
if name == "":
continue
elif name not in layer_known_edges_names:
new_node.input[i] = self.get_uname(layer_id, name)
self.add_fp32_value_info(new_node.input[i])
for i, name in enumerate(new_node.output):
if name == "":
continue
elif name not in layer_known_edges_names:
new_node.output[i] = self.get_uname(layer_id, name)
self.add_fp32_value_info(new_node.output[i])
new_node.name = self.get_uname(layer_id, new_node.name)
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
assert len(inputs) == 3
assert len(outputs) == 1
node = helper.make_node(
"LayerNormalization",
inputs=inputs,
outputs=outputs,
name=prefix + "_LayerNormalization",
epsilon=9.999999747378752e-06,
)
return [node]
def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
assert len(inputs) == 3
assert len(outputs) == 1
matmul = helper.make_node(
"MatMul",
inputs=[inputs[0], inputs[1]],
outputs=[prefix + "matmul_out"],
name=prefix + "MatMul",
)
add = helper.make_node(
"Add",
inputs=[prefix + "matmul_out", inputs[2]],
outputs=outputs,
name=prefix + "Bias",
)
return [matmul, add]
def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32):
assert len(inputs) == 4
assert len(outputs) == 1
node = helper.make_node(
"RotaryEmbedding",
inputs=inputs,
outputs=outputs,
name=prefix + "RotaryEmbedding",
domain="com.microsoft",
rotary_embedding_dim=rot_dim,
num_heads=num_heads,
)
return [node]
def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""):
assert len(inputs) == 1
assert len(outputs) == 1
node = helper.make_node(
"FastGelu",
inputs=inputs,
outputs=outputs,
name=prefix + "FastGelu",
domain="com.microsoft",
)
return [node]
def add(self, inputs: List[str], outputs: List[str], prefix: str = ""):
assert len(inputs) == 2
assert len(outputs) == 1
node = helper.make_node(
"Add",
inputs=inputs,
outputs=outputs,
name=prefix + "Add",
)
return [node]
def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
assert len(inputs) == 8
assert len(outputs) == 3
node = helper.make_node(
"MultiHeadAttention",
inputs=inputs,
outputs=outputs,
name=prefix + "MultiHeadAttention",
domain="com.microsoft",
num_heads=num_heads,
unidirectional=1,
)
return [node]
def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
assert len(inputs) == 7
assert len(outputs) == 3
node = helper.make_node(
"GroupQueryAttention",
inputs=inputs,
outputs=outputs,
name=prefix + "GroupQueryAttention",
domain="com.microsoft",
num_heads=num_heads,
kv_num_heads=num_heads,
)
return [node]
def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
assert len(inputs) == 5
assert len(outputs) == 2
node = helper.make_node(
"Attention",
inputs=inputs,
outputs=outputs,
name=prefix + "Attention",
domain="com.microsoft",
num_heads=num_heads,
unidirectional=1,
do_rotary=1,
rotary_embedding_dim=32,
)
return [node]
class Phi2PreProcessor(DynamoOnnxHelper):
def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
super().__init__(model)
self.num_hidden_layers = 32
self.num_attention_heads = num_heads
self.hidden_size = hidden_size
self.phi2_edge_dict = self.get_phi2_edge_dict()
self.func_name = "modeling_phi_PhiModel_model_1"
def get_phi2_edge_dict(self) -> dict:
edge_dict = {}
edge_dict["lm_head_1"] = "logits"
edge_dict["l_input_ids_"] = "input_ids"
edge_dict["key_states"] = "past_key_0"
edge_dict["value_states"] = "past_value_0"
for i in range(self.num_hidden_layers):
edge_dict[f"key_states_{i}"] = f"past_key_{i}"
edge_dict[f"value_states_{i}"] = f"past_value_{i}"
edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
return edge_dict
def simplify_phi2_op_type(self):
phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
for node in self.model.graph.node:
index = node.op_type.find(phi2_transformer_layer_name)
if index != -1:
node.op_type = node.op_type[index:]
def process_graph_io(self, attn_op_type: AttentionOpType):
self.use_attn = attn_op_type == AttentionOpType.Attention
graph = self.model.graph
new_inputs = []
for vi in graph.input:
if "input_ids" in vi.name:
vi_iid = helper.make_tensor_value_info(
vi.name,
elem_type=TensorProto.INT32,
shape=["batch_size", "seq_len"],
)
vi_pid = helper.make_tensor_value_info(
"step",
elem_type=TensorProto.INT64,
shape=[1],
)
vi_mask = helper.make_tensor_value_info(
"attention_mask",
elem_type=TensorProto.INT32,
shape=["batch_size", "seq_len"],
)
new_inputs.extend([vi_iid, vi_pid, vi_mask])
if not self.use_attn:
if "past_key" in vi.name or "past_value" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
"batch_size",
self.num_attention_heads,
"past_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_inputs.extend([vi_cache])
else:
if "past_key" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name.replace("past_key", "past"),
elem_type=vi.type.tensor_type.elem_type,
shape=[
2,
"batch_size",
self.num_attention_heads,
"past_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_inputs.extend([vi_cache])
graph.ClearField("input")
graph.input.extend(new_inputs)
new_outputs = []
for i, vi in enumerate(graph.output):
if i == 0:
new_outputs.extend([vi])
else:
if not self.use_attn:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
"batch_size",
self.num_attention_heads,
"total_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_outputs.extend([vi_cache])
else:
if "present_key" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name.replace("present_key", "present"),
elem_type=vi.type.tensor_type.elem_type,
shape=[
2,
"batch_size",
self.num_attention_heads,
"total_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_outputs.extend([vi_cache])
graph.ClearField("output")
graph.output.extend(new_outputs)
def preprocess_onnx(self, attn_op_type: AttentionOpType):
function_name = None
for func in self.model.functions:
if func.name.endswith(self.func_name):
function_name = func.name
break
assert function_name is not None
self.unroll_function(function_name)
self.update_edges(self.phi2_edge_dict)
self.simplify_phi2_op_type()
self.remove_dropout_layer()
self.process_graph_io(attn_op_type)
class FissionTransformerEmbeddingPhi(Fission):
def __init__(
self,
model: OnnxModel,
):
super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"])
def fuse(self, node, input_name_to_nodes, output_name_to_node):
logger.info("Optimizing %s...", node.name)
assert len(node.input) == 2
assert len(node.output) == 1
input = node.input[0]
output = node.output[0]
embedding = self.get_io_by_name(node, "embed_tokens.weight")
layer_known_edges_names = [input, output, embedding]
subgraph_nodes = [
helper.make_node(
"Gather",
inputs=[embedding, input],
outputs=[output],
name="Embedding_Gather",
),
]
self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names)
self.nodes_to_remove.append(node)
self.prune_graph = True
class FissionTransformerLayerNormPhi(Fission):
def __init__(
self,
model: OnnxModel,
):
super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"])
def fuse(self, node, input_name_to_nodes, output_name_to_node):
logger.info("Optimizing %s...", node.name)
assert len(node.input) == 3
assert len(node.output) == 1
input = node.input[0]
output = node.output[0]
ln_weight = self.get_io_by_name(node, "final_layernorm.weight")
ln_bias = self.get_io_by_name(node, "final_layernorm.bias")
layer_known_edges_names = [input, output, ln_weight, ln_bias]
subgraph_nodes = []
subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final"))
self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"])
self.nodes_to_remove.append(node)
self.prune_graph = True
class FissionTransformerCausalLMHeadPhi(Fission):
def __init__(
self,
model: OnnxModel,
):
super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"])
def fuse(self, node, input_name_to_nodes, output_name_to_node):
logger.info("Optimizing %s...", node.name)
assert len(node.input) == 5
assert len(node.output) == 1
input = node.input[2]
output = node.output[0]
fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
fc_bias = self.get_io_by_name(node, "lm_head.bias")
layer_known_edges_names = [input, output, fc_weight, fc_bias]
subgraph_nodes = []
subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_"))
self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200])
self.nodes_to_remove.append(node)
self.prune_graph = True
class FissionTransformerBlockPhi(Fission):
def __init__(
self,
model: OnnxModel,
num_heads: int,
):
self.num_heads = num_heads
max_num_layers = 32
self.func_to_layer_id = {}
nodes_to_find = []
for layer in range(max_num_layers):
func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1"
nodes_to_find.append(func_name)
self.func_to_layer_id[func_name] = layer
super().__init__(model, nodes_to_find)
def get_layer_id(self, node):
return self.func_to_layer_id[node.op_type]
def get_gqa_aux_nodes(self):
gqa_aux_nodes = [
helper.make_node(
"Cast",
inputs=["attention_mask"],
outputs=["mask_int64"],
name="Cast_gqa_aux_0",
to=TensorProto.INT64,
),
helper.make_node(
"ReduceSum",
inputs=["mask_int64", "one"],
outputs=["mask_row_sums"],
name="ReduceSum_gqa_aux",
),
helper.make_node(
"Sub",
inputs=["mask_row_sums", "one"],
outputs=["seqlens_k_int64"],
name="Sub_gqa_aux",
),
helper.make_node(
"Cast",
inputs=["seqlens_k_int64"],
outputs=["seqlens_k"],
name="Cast_gqa_aux_1",
to=TensorProto.INT32,
),
helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"),
helper.make_node(
"Gather",
inputs=["mask_shape", "one"],
outputs=["total_seq_len_int64"],
name="Gather_gqa_aux_0",
axis=0,
),
helper.make_node(
"Cast",
inputs=["total_seq_len_int64"],
outputs=["total_sequence_length"],
name="Cast_gqa_aux_2",
to=TensorProto.INT32,
),
]
return gqa_aux_nodes
def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name):
q_weight = self.model.get_initializer(q_w)
k_weight = self.model.get_initializer(k_w)
v_weight = self.model.get_initializer(v_w)
qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0))
kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0))
vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0))
qkv_weight = np.stack((qw, kw, vw), axis=1)
q_bias = self.model.get_initializer(q_b)
k_bias = self.model.get_initializer(k_b)
v_bias = self.model.get_initializer(v_b)
qb = NumpyHelper.to_array(q_bias)
kb = NumpyHelper.to_array(k_bias)
vb = NumpyHelper.to_array(v_bias)
qkv_bias = np.stack((qb, kb, vb), axis=0)
hidden_size = qkv_weight.shape[0]
weight = helper.make_tensor(
weight_name,
data_type=TensorProto.FLOAT,
dims=[hidden_size, hidden_size * 3],
vals=qkv_weight.flatten().tobytes(),
raw=True,
)
self.model.add_initializer(weight, self.this_graph_name)
bias = helper.make_tensor(
bias_name,
data_type=TensorProto.FLOAT,
dims=[hidden_size * 3],
vals=qkv_bias.flatten().tobytes(),
raw=True,
)
self.model.add_initializer(bias, self.this_graph_name)
self.add_fp32_value_info(weight.name)
self.add_fp32_value_info(bias.name)
return weight_name, bias_name
def fuse(
self,
node,
input_name_to_nodes,
output_name_to_node,
):
logger.info("Optimizing %s...", node.name)
logger.info(f"AttentionOpType: {self.attn_op_type}")
layer_id = self.get_layer_id(node)
i_hidden_states = node.input[0]
i_key_cache = self.get_io_by_name(node, "past_key")
i_value_cache = self.get_io_by_name(node, "past_value")
o_hidden_states = node.output[3]
o_key_cache = self.get_io_by_name(node, "present_key")
o_value_cache = self.get_io_by_name(node, "present_value")
ln_weight = self.get_io_by_name(node, "input_layernorm.weight")
ln_bias = self.get_io_by_name(node, "input_layernorm.bias")
attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
None,
None,
None,
None,
None,
None,
)
attn_qkv_weight, attn_qkv_bias = None, None
cos_cache, sin_cache = None, None
if self.attn_op_type != AttentionOpType.Attention:
attn_q_weight = self.process_initializer(
self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
)
attn_k_weight = self.process_initializer(
self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
)
attn_v_weight = self.process_initializer(
self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
)
attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias")
attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias")
attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias")
cos_cache = self.process_initializer(
self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
)
sin_cache = self.process_initializer(
self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
)
else:
attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
self.get_io_by_name(node, "self_attn.q_proj.weight"),
self.get_io_by_name(node, "self_attn.k_proj.weight"),
self.get_io_by_name(node, "self_attn.v_proj.weight"),
self.get_io_by_name(node, "self_attn.q_proj.bias"),
self.get_io_by_name(node, "self_attn.k_proj.bias"),
self.get_io_by_name(node, "self_attn.v_proj.bias"),
self.get_uname(layer_id, "attn_qkv_weight"),
self.get_uname(layer_id, "attn_qkv_bias"),
)
attn_out_weight = self.process_initializer(
self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
)
attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias")
mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias")
mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias")
layer_known_edges_names = []
layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache])
layer_known_edges_names.extend([ln_weight, ln_bias])
if self.attn_op_type != AttentionOpType.Attention:
layer_known_edges_names.extend(
[
attn_q_weight,
attn_q_bias,
attn_k_weight,
attn_k_bias,
attn_v_weight,
attn_v_bias,
cos_cache,
sin_cache,
]
)
else:
layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias])
layer_known_edges_names.extend(
[attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
)
layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"])
subgraph_nodes = []
subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_"))
subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_"))
subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"]))
subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_"))
subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1"))
subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2"))
if self.attn_op_type != AttentionOpType.Attention:
subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_"))
subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_"))
if self.attn_op_type == AttentionOpType.MultiHeadAttention:
subgraph_nodes.extend(
self.mha(
["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache],
["attn_out", o_key_cache, o_value_cache],
)
)
elif self.attn_op_type == AttentionOpType.GroupQueryAttention:
subgraph_nodes.extend(
self.gqa(
[
"query_rot",
"key_rot",
"value",
i_key_cache,
i_value_cache,
"seqlens_k",
"total_sequence_length",
],
["attn_out", o_key_cache, o_value_cache],
)
)
if layer_id == 0:
gqa_aux_nodes = self.get_gqa_aux_nodes()
for new_node in gqa_aux_nodes:
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.model.add_initializer(
numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
)
else:
past_name = f"past_{layer_id}"
present_name = f"present_{layer_id}"
layer_known_edges_names.extend([past_name, present_name])
subgraph_nodes.extend(
self.attention(
["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name]
)
)
self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)
self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])
self.nodes_to_remove.append(node)
self.prune_graph = True
class PhiOnnxModel(OnnxModel):
def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
super().__init__(model)
self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size)
self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads)
self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self)
self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self)
self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self)
def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
assert options is not None
attn_op_type = options.attention_op_type
self.fission_transformer_block.set_attention_op_type(attn_op_type)
self.phi2_preprocessor.preprocess_onnx(attn_op_type)
self.fission_transformer_block.apply()
self.fission_transformer_layernorm.apply()
self.fission_causal_lm_head.apply()
self.fission_transformer_embedding.apply()
super().prune_graph()
# SLN ctor is placed here intentionally to delay the symbolic shape inference
self.fuse_sln = FusionSkipLayerNormalization(self)
self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self)
self.fuse_sln.apply()
self.fuse_bias_sln.apply()
def get_fused_operator_statistics(self):
"""
Returns node count of fused operators.
"""
op_count = {}
ops = [
"Attention",
"MultiHeadAttention",
"GroupQueryAttention",
"Gelu",
"BiasGelu",
"FastGelu",
"LayerNormalization",
"SkipLayerNormalization",
]
for op in ops:
nodes = self.get_nodes_by_op_type(op)
op_count[op] = len(nodes)
logger.info(f"Optimized operators: {op_count}")
return op_count
def is_fully_optimized(self, fused_op_count=None):
"""
Returns True when the model is fully optimized.
"""
if fused_op_count is None:
fused_op_count = self.get_fused_operator_statistics()
def op_count(op_name: str):
return fused_op_count.get(op_name) or 0
attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("GroupQueryAttention")
gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention)
if layer_norm == 0:
logger.debug("Layer Normalization not fused")
if gelu == 0:
logger.debug("Gelu (or FastGelu) not fused")
if attention == 0:
logger.warning("Attention (or MultiHeadAttention) not fused")
return is_perfect


@ -34,6 +34,7 @@ from onnx_model_bert_tf import BertOnnxModelTF
from onnx_model_clip import ClipOnnxModel
from onnx_model_conformer import ConformerOnnxModel
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_phi import PhiOnnxModel
from onnx_model_t5 import T5OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
from onnx_model_unet import UnetOnnxModel
@ -58,6 +59,7 @@ MODEL_TYPES = {
"vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion
"vit": (BertOnnxModel, "pytorch", 1),
"conformer": (ConformerOnnxModel, "pytorch", 1),
"phi": (PhiOnnxModel, "pytorch", 0),
}
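With this registration, the standard optimizer entry point can produce a PhiOnnxModel. A minimal sketch, mirroring the phi2 conversion script earlier in this change (the input path is illustrative; num_heads=32 and hidden_size=2560 are the phi-2 configuration values):

```python
from fusion_options import AttentionOpType, FusionOptions
from optimizer import optimize_model

options = FusionOptions("phi")
options.set_attention_op_type(AttentionOpType.GroupQueryAttention)

# opt_level=0 matches the "phi" entry registered above.
optimizer = optimize_model(
    "phi2_original.onnx",
    model_type="phi",
    num_heads=32,
    hidden_size=2560,
    opt_level=0,
    optimization_options=options,
    only_onnxruntime=False,
)
optimizer.save_model_to_file("phi2_decoder_fp32_gpu.onnx", use_external_data_format=True)
```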


@ -419,6 +419,7 @@ packages = [
"onnxruntime.transformers.models.gpt2",
"onnxruntime.transformers.models.llama",
"onnxruntime.transformers.models.longformer",
"onnxruntime.transformers.models.phi2",
"onnxruntime.transformers.models.t5",
"onnxruntime.transformers.models.stable_diffusion",
"onnxruntime.transformers.models.whisper",