зеркало из https://github.com/microsoft/archai.git
chore(nlp): Finishes docstringing ONNX-related files.
This commit is contained in:
Родитель
553033173e
Коммит
ab39978eb3
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""ONNX-related configuration utilities.
|
||||
"""
|
|
@ -4,13 +4,21 @@
|
|||
"""GPT-2 ONNX configuration.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from archai.nlp.onnx.config_utils.onnx_config_base import OnnxConfigWithPast
|
||||
|
||||
|
||||
class GPT2OnnxConfig(OnnxConfigWithPast):
|
||||
def __init__(self, config, task="causal-lm", use_past=False) -> None:
|
||||
"""Implements a GPT-2 ONNX configuration (with past key/values support)."""
|
||||
|
||||
def __init__(
|
||||
self, config: PretrainedConfig, task: Optional[str] = "causal-lm", use_past: Optional[bool] = False
|
||||
) -> None:
|
||||
super().__init__(config, task=task, use_past=use_past, past_key_values=2)
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self.config.n_layer
|
||||
return self.config.n_layer
|
||||
|
|
|
@ -4,28 +4,33 @@
|
|||
"""ONNX-based configuration.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from typing import Mapping, Optional
|
||||
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class OnnxConfig:
|
||||
""""""
|
||||
"""Implements the base ONNX configuration."""
|
||||
|
||||
DEFAULT_BATCH_SIZE = 2
|
||||
DEFAULT_SEQ_LEN = 8
|
||||
DEFAULT_TASK_OUTPUTS = {
|
||||
"causal-lm": OrderedDict({"probs": {0: "batch_size"}})
|
||||
}
|
||||
DEFAULT_TASK_OUTPUTS = {"causal-lm": OrderedDict({"probs": {0: "batch_size"}})}
|
||||
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
task: Optional[str] = "causal-lm",
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
task: Optional[str] = "causal-lm",
|
||||
) -> None:
|
||||
""""""
|
||||
"""Initializes by verifying whether `task` is supported.
|
||||
|
||||
Args:
|
||||
config: Model's configuration.
|
||||
task: Type of task that the exported model will be used.
|
||||
|
||||
"""
|
||||
|
||||
assert task in self.DEFAULT_TASK_OUTPUTS.keys(), f"`task`: {task} is not supported yet."
|
||||
|
||||
|
@ -34,38 +39,79 @@ class OnnxConfig:
|
|||
|
||||
@property
|
||||
def batch_size(self) -> int:
|
||||
""""""
|
||||
"""Batch size.
|
||||
|
||||
Returns:
|
||||
(int): Default batch size.
|
||||
|
||||
"""
|
||||
|
||||
return self.DEFAULT_BATCH_SIZE
|
||||
|
||||
@property
|
||||
def seq_len(self) -> int:
|
||||
""""""
|
||||
"""Sequence length.
|
||||
|
||||
Returns:
|
||||
(int): Default sequence length.
|
||||
|
||||
"""
|
||||
|
||||
return self.DEFAULT_SEQ_LEN
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
""""""
|
||||
"""ONNX-based inputs structure.
|
||||
|
||||
Returns:
|
||||
(Mapping[str, Mapping[int, str]]): ONNX-based inputs.
|
||||
|
||||
"""
|
||||
|
||||
return OrderedDict({"input_ids": {0: "batch_size", 1: "seq_len"}})
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
""""""
|
||||
"""ONNX-based outputs structure.
|
||||
|
||||
Returns:
|
||||
(Mapping[str, Mapping[int, str]]): ONNX-based outputs.
|
||||
|
||||
"""
|
||||
|
||||
return copy.deepcopy(self.DEFAULT_TASK_OUTPUTS[self.task])
|
||||
|
||||
def generate_dummy_inputs(self):
|
||||
""""""
|
||||
|
||||
def generate_dummy_inputs(self) -> Mapping[str, torch.Tensor]:
|
||||
"""Generates dummy inputs for the ONNX exporter.
|
||||
|
||||
Returns:
|
||||
(Mapping[str, Any]): Keyword arguments for the model's forward.
|
||||
|
||||
"""
|
||||
|
||||
return {"input_ids": torch.zeros((self.batch_size, self.seq_len), dtype=torch.long)}
|
||||
|
||||
|
||||
class OnnxConfigWithPast(OnnxConfig):
|
||||
""""""
|
||||
"""Implements the base ONNX configuration with support for past key/values."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
task: Optional[str] = "causal-lm",
|
||||
use_past: Optional[bool] = False,
|
||||
past_key_values: Optional[int] = 2,
|
||||
) -> None:
|
||||
"""Overrides initialization and defines whether past key/values are used.
|
||||
|
||||
Args:
|
||||
config: Model's configuration.
|
||||
task: Type of task that the exported model will be used.
|
||||
use_past: Whether past key/values (`use_cache`) should be used.
|
||||
past_key_values: Number of past-related information (2 for key and values).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config, task: Optional[str] = "causal-lm", use_past: Optional[bool] = False, past_key_values: Optional[int] = 2) -> None:
|
||||
super().__init__(config, task=task)
|
||||
|
||||
if use_past:
|
||||
|
@ -78,6 +124,13 @@ class OnnxConfigWithPast(OnnxConfig):
|
|||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
"""Dimensionality of hidden units.
|
||||
|
||||
Returns:
|
||||
(int): Hidden units size.
|
||||
|
||||
"""
|
||||
|
||||
if not hasattr(self.config, "hidden_size"):
|
||||
raise AttributeError()
|
||||
|
||||
|
@ -85,6 +138,13 @@ class OnnxConfigWithPast(OnnxConfig):
|
|||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
"""Number of layers.
|
||||
|
||||
Returns:
|
||||
(int): Number of layers.
|
||||
|
||||
"""
|
||||
|
||||
if not hasattr(self.config, "num_layers"):
|
||||
raise AttributeError()
|
||||
|
||||
|
@ -92,6 +152,13 @@ class OnnxConfigWithPast(OnnxConfig):
|
|||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
"""Number of attention heads.
|
||||
|
||||
Returns:
|
||||
(int): Number of attention heads.
|
||||
|
||||
"""
|
||||
|
||||
if not hasattr(self.config, "num_attention_heads"):
|
||||
raise AttributeError()
|
||||
|
||||
|
@ -99,6 +166,13 @@ class OnnxConfigWithPast(OnnxConfig):
|
|||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
"""ONNX-based inputs structure.
|
||||
|
||||
Returns:
|
||||
(Mapping[str, Mapping[int, str]]): ONNX-based inputs.
|
||||
|
||||
"""
|
||||
|
||||
inputs = super().inputs
|
||||
|
||||
if self.use_past:
|
||||
|
@ -110,6 +184,13 @@ class OnnxConfigWithPast(OnnxConfig):
|
|||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
"""ONNX-based outputs structure.
|
||||
|
||||
Returns:
|
||||
(Mapping[str, Mapping[int, str]]): ONNX-based outputs.
|
||||
|
||||
"""
|
||||
|
||||
outputs = super().outputs
|
||||
|
||||
if self.use_past:
|
||||
|
@ -120,8 +201,13 @@ class OnnxConfigWithPast(OnnxConfig):
|
|||
|
||||
return outputs
|
||||
|
||||
def generate_dummy_inputs(self):
|
||||
""""""
|
||||
def generate_dummy_inputs(self) -> Mapping[str, torch.Tensor]:
|
||||
"""Generates dummy inputs for the ONNX exporter.
|
||||
|
||||
Returns:
|
||||
(Mapping[str, Any]): Keyword arguments for the model's forward.
|
||||
|
||||
"""
|
||||
|
||||
dummy_inputs = super().generate_dummy_inputs()
|
||||
|
||||
|
@ -129,7 +215,13 @@ class OnnxConfigWithPast(OnnxConfig):
|
|||
# [past_key_values, batch_size, n_head, past_seq_len, d_head]
|
||||
dummy_inputs["past_key_values"] = tuple(
|
||||
[
|
||||
torch.zeros(self.config.past_key_values, self.batch_size, self.num_attention_heads, self.seq_len, self.hidden_size // self.num_attention_heads)
|
||||
torch.zeros(
|
||||
self.config.past_key_values,
|
||||
self.batch_size,
|
||||
self.num_attention_heads,
|
||||
self.seq_len,
|
||||
self.hidden_size // self.num_attention_heads,
|
||||
)
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -49,9 +49,9 @@ def validate_onnx_outputs(
|
|||
|
||||
ref_inputs = config.generate_dummy_inputs()
|
||||
ref_outputs = reference_model(**ref_inputs)
|
||||
ref_outputs_dict = {}
|
||||
|
||||
|
||||
# Flattens the reference outputs
|
||||
ref_outputs_dict = {}
|
||||
for name, value in ref_outputs.items():
|
||||
if name == "past_key_values":
|
||||
name = "present"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"""ONNX-compliant forward functions.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -14,7 +14,7 @@ def gpt2_onnx_forward(
|
|||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[Tuple[torch.FloatTensor, ...]] = None,
|
||||
) -> Tuple[torch.FloatTensor, ...]:
|
||||
) -> Dict[str, torch.FloatTensor]:
|
||||
"""Overrides the GPT-2 forward by returning probabilities and past key/values.
|
||||
|
||||
Args:
|
||||
|
@ -22,7 +22,7 @@ def gpt2_onnx_forward(
|
|||
past_key_values: Past pre-computed key/values tensor.
|
||||
|
||||
Returns:
|
||||
(Tuple[torch.FloatTensor, ...]): Output probabilities and past key/values.
|
||||
(Dict[str, torch.FloatTensor]): Output probabilities and past key/values.
|
||||
|
||||
"""
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ def optimize_onnx(
|
|||
onnx_model_path: str,
|
||||
model_config: PretrainedConfig,
|
||||
use_gpu: Optional[bool] = False,
|
||||
opt_level: Optional[int] = 0,
|
||||
opt_level: Optional[int] = 1,
|
||||
only_ort: Optional[bool] = False,
|
||||
float16: Optional[bool] = False,
|
||||
input_int32: Optional[bool] = False,
|
||||
|
@ -91,8 +91,8 @@ def optimize_onnx(
|
|||
optimizer_args = (ort_model,)
|
||||
if model_type in ["gpt2", "gpt2-flex"]:
|
||||
optimizer_args += (model_config.num_attention_heads, model_config.hidden_size)
|
||||
|
||||
optimizer = onnx_opt_model(*optimizer_args)
|
||||
|
||||
options = FusionOptions(model_type)
|
||||
|
||||
optimizer.optimize(options)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Provides post-exported ONNX optimization classes and methods.
|
||||
"""ONNX-related optimization utilities.
|
||||
"""
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""ONNX optimization models.
|
||||
"""Transformer-XL ONNX optimization model.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
Загрузка…
Ссылка в новой задаче