chore(nlp): Finishes docstringing ONNX-related files.

This commit is contained in:
Gustavo Rosa 2022-11-08 05:03:58 -08:00
Родитель 553033173e
Коммит ab39978eb3
8 изменённых файлов: 138 добавлений и 33 удалений

Просмотреть файл

@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""ONNX-related configuration utilities.
"""

Просмотреть файл

@ -4,13 +4,21 @@
"""GPT-2 ONNX configuration.
"""
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
from archai.nlp.onnx.config_utils.onnx_config_base import OnnxConfigWithPast
class GPT2OnnxConfig(OnnxConfigWithPast):
def __init__(self, config, task="causal-lm", use_past=False) -> None:
"""Implements a GPT-2 ONNX configuration (with past key/values support)."""
def __init__(
self, config: PretrainedConfig, task: Optional[str] = "causal-lm", use_past: Optional[bool] = False
) -> None:
super().__init__(config, task=task, use_past=use_past, past_key_values=2)
@property
def num_layers(self) -> int:
return self.config.n_layer
return self.config.n_layer

Просмотреть файл

@ -4,28 +4,33 @@
"""ONNX-based configuration.
"""
from collections import OrderedDict
import copy
from collections import OrderedDict
from typing import Mapping, Optional
import torch
from transformers.configuration_utils import PretrainedConfig
class OnnxConfig:
""""""
"""Implements the base ONNX configuration."""
DEFAULT_BATCH_SIZE = 2
DEFAULT_SEQ_LEN = 8
DEFAULT_TASK_OUTPUTS = {
"causal-lm": OrderedDict({"probs": {0: "batch_size"}})
}
DEFAULT_TASK_OUTPUTS = {"causal-lm": OrderedDict({"probs": {0: "batch_size"}})}
def __init__(self,
config,
task: Optional[str] = "causal-lm",
def __init__(
self,
config: PretrainedConfig,
task: Optional[str] = "causal-lm",
) -> None:
""""""
"""Initializes by verifying whether `task` is supported.
Args:
config: Model's configuration.
task: Type of task that the exported model will be used for.
"""
assert task in self.DEFAULT_TASK_OUTPUTS.keys(), f"`task`: {task} is not supported yet."
@ -34,38 +39,79 @@ class OnnxConfig:
@property
def batch_size(self) -> int:
""""""
"""Batch size.
Returns:
(int): Default batch size.
"""
return self.DEFAULT_BATCH_SIZE
@property
def seq_len(self) -> int:
""""""
"""Sequence length.
Returns:
(int): Default sequence length.
"""
return self.DEFAULT_SEQ_LEN
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
""""""
"""ONNX-based inputs structure.
Returns:
(Mapping[str, Mapping[int, str]]): ONNX-based inputs.
"""
return OrderedDict({"input_ids": {0: "batch_size", 1: "seq_len"}})
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
""""""
"""ONNX-based outputs structure.
Returns:
(Mapping[str, Mapping[int, str]]): ONNX-based outputs.
"""
return copy.deepcopy(self.DEFAULT_TASK_OUTPUTS[self.task])
def generate_dummy_inputs(self):
""""""
def generate_dummy_inputs(self) -> Mapping[str, torch.Tensor]:
"""Generates dummy inputs for the ONNX exporter.
Returns:
(Mapping[str, torch.Tensor]): Keyword arguments for the model's forward.
"""
return {"input_ids": torch.zeros((self.batch_size, self.seq_len), dtype=torch.long)}
class OnnxConfigWithPast(OnnxConfig):
""""""
"""Implements the base ONNX configuration with support for past key/values."""
def __init__(
self,
config: PretrainedConfig,
task: Optional[str] = "causal-lm",
use_past: Optional[bool] = False,
past_key_values: Optional[int] = 2,
) -> None:
"""Overrides initialization and defines whether past key/values are used.
Args:
config: Model's configuration.
task: Type of task that the exported model will be used for.
use_past: Whether past key/values (`use_cache`) should be used.
past_key_values: Number of past-related tensors (2: one for keys and one for values).
"""
def __init__(self, config, task: Optional[str] = "causal-lm", use_past: Optional[bool] = False, past_key_values: Optional[int] = 2) -> None:
super().__init__(config, task=task)
if use_past:
@ -78,6 +124,13 @@ class OnnxConfigWithPast(OnnxConfig):
@property
def hidden_size(self) -> int:
"""Dimensionality of hidden units.
Returns:
(int): Hidden units size.
"""
if not hasattr(self.config, "hidden_size"):
raise AttributeError()
@ -85,6 +138,13 @@ class OnnxConfigWithPast(OnnxConfig):
@property
def num_layers(self) -> int:
"""Number of layers.
Returns:
(int): Number of layers.
"""
if not hasattr(self.config, "num_layers"):
raise AttributeError()
@ -92,6 +152,13 @@ class OnnxConfigWithPast(OnnxConfig):
@property
def num_attention_heads(self) -> int:
"""Number of attention heads.
Returns:
(int): Number of attention heads.
"""
if not hasattr(self.config, "num_attention_heads"):
raise AttributeError()
@ -99,6 +166,13 @@ class OnnxConfigWithPast(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
"""ONNX-based inputs structure.
Returns:
(Mapping[str, Mapping[int, str]]): ONNX-based inputs.
"""
inputs = super().inputs
if self.use_past:
@ -110,6 +184,13 @@ class OnnxConfigWithPast(OnnxConfig):
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
"""ONNX-based outputs structure.
Returns:
(Mapping[str, Mapping[int, str]]): ONNX-based outputs.
"""
outputs = super().outputs
if self.use_past:
@ -120,8 +201,13 @@ class OnnxConfigWithPast(OnnxConfig):
return outputs
def generate_dummy_inputs(self):
""""""
def generate_dummy_inputs(self) -> Mapping[str, torch.Tensor]:
"""Generates dummy inputs for the ONNX exporter.
Returns:
(Mapping[str, Any]): Keyword arguments for the model's forward.
"""
dummy_inputs = super().generate_dummy_inputs()
@ -129,7 +215,13 @@ class OnnxConfigWithPast(OnnxConfig):
# [past_key_values, batch_size, n_head, past_seq_len, d_head]
dummy_inputs["past_key_values"] = tuple(
[
torch.zeros(self.config.past_key_values, self.batch_size, self.num_attention_heads, self.seq_len, self.hidden_size // self.num_attention_heads)
torch.zeros(
self.config.past_key_values,
self.batch_size,
self.num_attention_heads,
self.seq_len,
self.hidden_size // self.num_attention_heads,
)
for _ in range(self.num_layers)
]
)

Просмотреть файл

@ -49,9 +49,9 @@ def validate_onnx_outputs(
ref_inputs = config.generate_dummy_inputs()
ref_outputs = reference_model(**ref_inputs)
ref_outputs_dict = {}
# Flattens the reference outputs
ref_outputs_dict = {}
for name, value in ref_outputs.items():
if name == "past_key_values":
name = "present"

Просмотреть файл

@ -4,7 +4,7 @@
"""ONNX-compliant forward functions.
"""
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
@ -14,7 +14,7 @@ def gpt2_onnx_forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[Tuple[torch.FloatTensor, ...]] = None,
) -> Tuple[torch.FloatTensor, ...]:
) -> Dict[str, torch.FloatTensor]:
"""Overrides the GPT-2 forward by returning probabilities and past key/values.
Args:
@ -22,7 +22,7 @@ def gpt2_onnx_forward(
past_key_values: Past pre-computed key/values tensor.
Returns:
(Tuple[torch.FloatTensor, ...]): Output probabilities and past key/values.
(Dict[str, torch.FloatTensor]): Output probabilities and past key/values.
"""

Просмотреть файл

@ -26,7 +26,7 @@ def optimize_onnx(
onnx_model_path: str,
model_config: PretrainedConfig,
use_gpu: Optional[bool] = False,
opt_level: Optional[int] = 0,
opt_level: Optional[int] = 1,
only_ort: Optional[bool] = False,
float16: Optional[bool] = False,
input_int32: Optional[bool] = False,
@ -91,8 +91,8 @@ def optimize_onnx(
optimizer_args = (ort_model,)
if model_type in ["gpt2", "gpt2-flex"]:
optimizer_args += (model_config.num_attention_heads, model_config.hidden_size)
optimizer = onnx_opt_model(*optimizer_args)
options = FusionOptions(model_type)
optimizer.optimize(options)

Просмотреть файл

@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Provides post-exported ONNX optimization classes and methods.
"""ONNX-related optimization utilities.
"""

Просмотреть файл

@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""ONNX optimization models.
"""Transformer-XL ONNX optimization model.
"""
from typing import List, Optional, Tuple