[ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959)

This PR adds support for efficient attention and flash attention in
ORTModule, including:
- Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev
or newer. Set ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable it.
- Integrate Triton flash attention, which requires
triton==2.0.0.dev20221202 and an A100 or H100 GPU.
Set ORTMODULE_USE_FLASH_ATTENTION=1 to enable it.
- A Python transformer tool that matches sub-graphs by config, making it
quick to write new transformers.

The current transformers support attention mask for both efficient attention
and flash attention, and dropout for efficient attention only. To support more
training scenarios (such as the causal mask in GPT-2), more transformers need
to be added.

The feature is guarded by environment variables, so it won't affect
any current behavior if not enabled. Since it requires specific
PyTorch/Triton versions, related tests are not added for now.
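As a quick reference, below is a minimal sketch of enabling these paths from Python; the placeholder model is an assumption for illustration, not code from this PR. The environment variables must be set before the ortmodule package is imported, because the optional kernels and graph optimizers are registered conditionally at import time.

```python
import os

# Set before importing onnxruntime.training.ortmodule; the optional kernels and
# graph optimizers are only registered when these variables are present at import time.
os.environ["ORTMODULE_USE_EFFICIENT_ATTENTION"] = "1"  # ATen efficient attention (PyTorch 2.2.0 dev or newer)
# os.environ["ORTMODULE_USE_FLASH_ATTENTION"] = "1"    # Triton flash attention (A100/H100, triton==2.0.0.dev20221202)

import torch
from onnxruntime.training.ortmodule import ORTModule

model = torch.nn.Linear(16, 16).cuda()  # placeholder; an attention-based transformer model is the real target
model = ORTModule(model)
```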
This commit is contained in:
Vincent Wang 2023-10-27 10:29:27 +08:00 committed by GitHub
Parent 37873be86d
Commit b7408f7389
26 changed files with 2037 additions and 93 deletions

View file

@ -387,6 +387,9 @@ if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*"
)
file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
)
file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
)
@ -741,6 +744,7 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
@ -794,6 +798,9 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_graph_optimizers_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ort_triton_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/

View file

@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const {
aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size,
dlpack_outputs.get());
for (size_t i = 0; i < output_size; ++i) {
ORT_RETURN_IF_ERROR(
p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(dlpack_outputs[i])));
if (dlpack_outputs[i]) {
ORT_RETURN_IF_ERROR(
p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(dlpack_outputs[i])));
}
}
return Status::OK();

View file

@ -10,7 +10,7 @@ namespace onnxruntime {
namespace contrib {
namespace aten_ops {
typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index);
typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input);
typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size,
DLManagedTensor** dlpack_inputs, size_t output_size,
DLManagedTensor** dlpack_outputs);
@ -22,17 +22,17 @@ class ATenOperatorExecutor {
return instance;
}
void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw);
p_is_tensor_argument_func_ = reinterpret_cast<IsTensorArgumentFunc>(p_is_tensor_argument_func_raw);
void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) {
ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw);
p_is_cpu_argument_func_ = reinterpret_cast<IsCpuArgumentFunc>(p_is_cpu_argument_func_raw);
p_execute_aten_op_func_ = reinterpret_cast<ExecuteATenOperatorFunc>(p_execute_aten_op_func_raw);
}
bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; }
bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) {
ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized.");
return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index);
bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) {
ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized.");
return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input);
}
void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size,
@ -43,7 +43,7 @@ class ATenOperatorExecutor {
}
private:
IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr;
IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr;
ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr;
};

View file

@ -9,6 +9,7 @@
#include "onnx/defs/data_type_utils.h"
#include "core/framework/op_kernel.h"
#include "core/framework/utils.h"
using namespace ONNX_NAMESPACE::Utils;
@ -77,7 +78,7 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
node->OutputDefs(),
[&](const NodeArg& node_arg, size_t out_index) {
if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) {
if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) {
cpu_output_args.insert(&node_arg);
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
for (auto& consumer_node : consumer_nodes) {

View file

@ -1025,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index)
overload_name = attrs.at("overload_name").s();
}
return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index);
return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true);
}
#else
ORT_UNUSED_PARAMETER(node);
#endif
return false;
}
bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) {
if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) {
return true;
}
#ifdef ENABLE_ATEN
if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
node.Domain() == kPytorchAtenDomain) {
const auto& attrs = node.GetAttributes();
ORT_ENFORCE(utils::HasString(attrs.at("operator")));
std::string op_name = attrs.at("operator").s();
std::string overload_name = "";
if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) {
overload_name = attrs.at("overload_name").s();
}
return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false);
}
#else
ORT_UNUSED_PARAMETER(node);

View file

@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet
bool sync_subgraph_fetches = false);
bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
template <typename T>
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {

View file

@ -249,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
if (!arg->Exists())
continue;
if (kci && kci->kernel_def->IsOutputOnCpu(i))
if (utils::IsOutputOnCpu(node, kci, i))
non_provider_output_defs_.insert(arg);
else
provider_output_defs_.insert(arg);
@ -308,7 +308,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it);
}
if (arg_output_index != -1) {
if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it);
if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it);
}
}
}
@ -404,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
// normally initializers are only inputs, but things may change with ops like assign
ORT_THROW_IF_ERROR(Node::ForEachWithIndex(
p_node->OutputDefs(),
[kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) {
if (kci->kernel_def->IsOutputOnCpu(index)) {
[kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) {
if (utils::IsOutputOnCpu(*p_node, kci, index)) {
ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end());
}
return Status::OK();

View file

@ -1214,14 +1214,14 @@ void addGlobalMethods(py::module& m) {
#ifdef ENABLE_ATEN
m.def("register_aten_op_executor",
[](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
size_t is_tensor_argument_address_int, aten_op_executor_address_int;
[](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
size_t is_cpu_argument_address_int, aten_op_executor_address_int;
ORT_THROW_IF_ERROR(
ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int));
ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int));
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
void* p_is_cpu_argument = reinterpret_cast<void*>(is_cpu_argument_address_int);
void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor);
});
#endif
}

View file

@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension():
from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor
_C.register_aten_op_executor(
str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address())
str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address())
)

View file

@ -154,11 +154,32 @@ class ATenOperatorCache {
std::unordered_map<std::pair<std::string, std::string>, ATenOperator, PairHash> ops_;
};
// Backend uses this function to check if an argument is CPU input (non-tensor argument) or not.
bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index) {
const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
return aten_op.elem_kinds[index] == c10::TypeKind::TensorType;
const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorInputsMap = {
{"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}};
const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorOutputsMap = {
{"_efficient_attention_forward", {2, 3}}};
// Backend uses this function to check if an argument is CPU input or not.
bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) {
if (is_input) {
// If the argument is non-tensor type, it's CPU argument.
const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) {
return true;
}
}
std::string full_name = std::string(op_name);
std::string overload_name_str = std::string(overload_name);
if (overload_name_str != "") {
full_name += ("." + overload_name_str);
}
const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap;
return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() &&
cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end();
}
void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size,
@ -196,14 +217,15 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t
size_t output_index = 0;
for (const auto& ret : torch::jit::pop(stack, output_size)) {
const auto& tensor = ret.toTensor();
dlpack_outputs[output_index++] = at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous());
dlpack_outputs[output_index++] =
tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr;
}
}
size_t is_tensor_argument_address() { return reinterpret_cast<size_t>(&IsTensorArgument); }
size_t is_cpu_argument_address() { return reinterpret_cast<size_t>(&IsCpuArgument); }
size_t execute_aten_operator_address() { return reinterpret_cast<size_t>(&ExecuteATenOperator); }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check.");
m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check.");
m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor");
}

View file

@ -5,7 +5,7 @@ import torch # noqa: F401
from onnxruntime.capi import _pybind_state as _C
from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address
from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address
def run_once_aten_op_executor(f):
@ -30,7 +30,7 @@ def run_once_aten_op_executor(f):
@run_once_aten_op_executor
def load_aten_op_executor_cpp_extension():
_C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address()))
_C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address()))
def init_aten_op_executor():

View file

@ -4180,6 +4180,7 @@ Return true if all elements are true and false otherwise.
.Attr("func_name", "Function name of the Python Triton kernel.", AttributeProto::STRING, std::string(""))
.Attr("onnx_key", "The hash key for the ONNX graph.", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("onnx_string", "The onnx string of the triton kernel.", AttributeProto::STRING, std::string(""))
.AllowUncheckedAttributes()
.Input(0, "inputs",
"Input tensors. If to call an existing Python Triton kernel, "
"the input count and order should match the arguments of the function. If to compute an ONNX graph, "

View file

@ -3,15 +3,28 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out
from ._slice_scel import slice_scel, slice_scel_backward, transform_slice_scel
import os
__all__ = [
from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401
from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward # noqa: F401
_all_kernels = [
"triton_gemm",
"triton_gemm_out",
"triton_matmul",
"triton_matmul_out",
"slice_scel",
"slice_scel_backward",
"transform_slice_scel",
]
_all_optimizers = [
"optimize_graph_for_slice_scel",
]
if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1:
from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention # noqa: F401
_all_kernels.extend(["flash_attn_forward", "flash_attn_backward"])
_all_optimizers.append("optimize_graph_for_flash_attention")
__all__ = _all_kernels + _all_optimizers # noqa: PLE0605

File diff not shown because it is too large. Load diff

View file

@ -11,7 +11,7 @@ import triton
import triton.language as tl
from onnx import TensorProto, helper
from onnxruntime.training.ortmodule import register_graph_transformer
from onnxruntime.training.ortmodule import register_graph_optimizer
from .._utils import get_attribute, to_numpy_array
@ -246,8 +246,8 @@ def _get_shape_related_nodes(graph, start_arg, sub_graph_nodes):
args.append(output)
@register_graph_transformer(devices="cuda")
def transform_slice_scel(graph):
@register_graph_optimizer(devices="cuda")
def optimize_graph_for_slice_scel(graph):
remove_nodes = []
triton_nodes = []
value_infos = []

View file

@ -124,7 +124,8 @@ def _are_deterministic_algorithms_enabled():
return ORTMODULE_IS_DETERMINISTIC
from .graph_transformer_registry import register_graph_transformer # noqa: E402, F401
from .graph_optimizer_registry import register_graph_optimizer # noqa: E402, F401
from .graph_optimizers import * # noqa: E402, F403
from .options import DebugOptions, LogLevel # noqa: E402, F401
# ORTModule must be loaded only after all validation passes

View file

@ -21,7 +21,7 @@ from ._io import _FlattenedModule, _InputInfo, unflatten_user_output
from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime
from ._runtime_inspector import Phase
from ._utils import save_tuning_results, set_tuning_results
from .graph_transformer_registry import GraphTransformerRegistry
from .graph_optimizer_registry import GraphOptimizerRegistry
from .options import DebugOptions, _SkipCheck
@ -369,7 +369,7 @@ class TrainingManager(GraphExecutionManager):
device_type = self._device.type
if device_type == "cuda" and self.is_rocm_pytorch:
device_type = "rocm"
GraphTransformerRegistry.transform_all(
GraphOptimizerRegistry.optimize_all(
type(self._flattened_module._original_module).__name__, device_type, self._onnx_models.optimized_model.graph
)

View file

@ -0,0 +1,47 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from typing import Callable
from onnx.onnx_ml_pb2 import GraphProto
class GraphOptimizerRegistry:
_OPTIMIZER_FUNCS = {} # noqa: RUF012
@classmethod
def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]):
modules = []
if target_modules == "all":
modules.append("all")
else:
modules = target_modules.split("|")
for module in modules:
if module in cls._OPTIMIZER_FUNCS:
cls._OPTIMIZER_FUNCS[module].append((fn, devices, priority))
else:
cls._OPTIMIZER_FUNCS[module] = [(fn, devices, priority)]
@classmethod
def optimize_all(cls, module_name: str, device: str, graph: GraphProto):
optimizers_to_apply = []
if "all" in cls._OPTIMIZER_FUNCS:
optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS["all"])
if module_name in cls._OPTIMIZER_FUNCS:
optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS[module_name])
optimizers_to_apply = [x for x in optimizers_to_apply if x[1] == "all" or device in x[1]]
optimizers_to_apply.sort(key=lambda x: x[2], reverse=True)
for fn, _, _ in optimizers_to_apply:
fn(graph)
# target_modules can be multiple module names separated by "|", or "all" means apply to all modules.
# devices can be multiple device types separated by "|" or "all" means apply to all devices.
def register_graph_optimizer(target_modules: str = "all", devices: str = "all", priority: int = 0):
def graph_optimizer_wrapper(fn):
GraphOptimizerRegistry.register(target_modules, devices, priority, fn)
return fn
return graph_optimizer_wrapper
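For illustration, here is a hedged sketch of registering a module-specific optimizer with this registry; the BertForSequenceClassification target module and the inspect-only body are assumptions for demonstration, not part of this change.

```python
from onnx import GraphProto

from onnxruntime.training.ortmodule import register_graph_optimizer


# Applied only when the exported module's class name matches and the device is CUDA.
@register_graph_optimizer(target_modules="BertForSequenceClassification", devices="cuda", priority=10)
def my_example_optimizer(graph: GraphProto) -> None:
    # Rewrite the ONNX GraphProto in place; this sketch only inspects it.
    print(f"Optimizing graph with {len(graph.node)} nodes")
```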

View file

@ -0,0 +1,15 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
_all_optimizers = []
if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1:
from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401
_all_optimizers.append("optimize_graph_for_aten_efficient_attention")
__all__ = _all_optimizers # noqa: PLE0605

View file

@ -0,0 +1,414 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs is keep changing. Current implementation
is tested well on version 2.2.0.dev20231010+cu121, and should be run well since official version 2.2.0. If may fail to
run is you are using PyTorch with older versions.
PyTorch also has API for flash attention (currently doesn't support random attention mask or Dropout), we can add
support if we want to try in the future.
"""
from typing import List, Tuple
from onnx import GraphProto, NodeProto, TensorProto, helper
from ..graph_optimizer_registry import register_graph_optimizer
from .utils import GraphMatcher, check_attribute_value, make_constant_node, update_graph
def _make_efficient_attention_nodes(
idx: int,
q: str,
k: str,
v: str,
y: str,
dy: str,
dq: str,
dk: str,
dv: str,
bias: str,
expand_bias: bool,
scale: float,
dropout_ratio: float,
causal: bool,
):
nodes_to_add = []
scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale])
dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio])
causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0])
int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0])
true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True])
false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False])
logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, [])
seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, [])
offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, [])
new_value_infos = [logsumexp, seed, offset]
if expand_bias:
shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1)
shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3)
shape_2 = helper.make_node("Shape", [q], ["shape_2_" + str(idx)], start=1, end=2)
shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2)
concat = helper.make_node(
"Concat",
["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)],
["concated_shape_" + str(idx)],
axis=0,
)
expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)])
nodes_to_add.extend([shape_0, shape_1, shape_2, shape_3, concat, expand])
bias = "expanded_bias_" + str(idx)
fwd_node = helper.make_node(
"ATen",
[
q,
k,
v,
bias,
"",
"",
"",
dropout_ratio_node.output[0],
causal_node.output[0],
true_node.output[0],
scale_node.output[0],
"",
"",
],
[y, logsumexp.name, seed.name, offset.name],
"efficient_attention_forward_" + str(idx),
None,
"org.pytorch.aten",
operator="_efficient_attention_forward",
)
bwd_node = helper.make_node(
"ATen",
[
dy,
q,
k,
v,
bias,
y,
"",
"",
int_zero_node.output[0],
int_zero_node.output[0],
logsumexp.name,
dropout_ratio_node.output[0],
seed.name,
offset.name,
causal_node.output[0],
false_node.output[0],
scale_node.output[0],
"",
],
[dq, dk, dv, ""],
"efficient_attention_backward_" + str(idx),
None,
"org.pytorch.aten",
operator="_efficient_attention_backward",
)
nodes_to_add.extend(
[scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node]
)
return nodes_to_add, new_value_infos
# Without causal mask, with Dropout. For example, BERT model in HuggingFace.
_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Transpose", True, [(0, 0, 0)]), # 1
("Transpose", True, [(0, 0, 1)]), # 2
("Div", False, [(0, 0, 0)]), # 3
("Add", False, [(3, 0, 0)]), # 4
("Softmax", False, [(4, 0, 0)]), # 5
("Dropout", False, [(5, 0, 0)]), # 6
("MatMul", False, [(6, 0, 0)]), # 7
("Transpose", True, [(7, 0, 1)]), # 8
("Transpose", False, [(7, 0, 0)]), # 9
("FusedMatMul", False, [(8, 0, 1)]), # 10
("DropoutGrad", False, [(10, 0, 0), (6, 1, 1)]), # 11
("SoftmaxGrad_13", False, [(11, 0, 0), (5, 0, 1)]), # 12
("Identity", False, [(12, 0, 0)]), # 13
("Div", False, [(13, 0, 0)]), # 14
("Identity", False, [(14, 0, 0)]), # 15
("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)]), # 16
("FusedMatMul", False, [(1, 0, 0), (15, 0, 1)]), # 17
("FusedMatMul", False, [(6, 0, 0)]), # 18
("Transpose", True, [(18, 0, 1)]), # 19
("Transpose", False, [(16, 0, 0)]), # 20
("Transpose", False, [(17, 0, 0)]), # 21
("Transpose", False, [(18, 0, 0)]), # 22
]
def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
# Check forward only as the backward is expected to be consistent if it's built correctly.
scale_value = matcher.get_constant_value(nodes[3].input[1])
ratio_value = matcher.get_constant_value(nodes[6].input[1])
if not (
check_attribute_value(nodes[1], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1])
and scale_value is not None
and ratio_value is not None
and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3])
):
return [], [], []
_, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0])
_, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1])
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
idx,
nodes[1].input[0],
nodes[2].input[0],
nodes[8].input[0],
nodes[9].output[0],
nodes[19].input[0],
nodes[20].output[0],
nodes[21].output[0],
nodes[22].output[0],
nodes[4].input[1],
add_input_shape_0 != add_input_shape_1,
1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
ratio_value,
False,
)
return nodes, nodes_to_add, new_value_infos
# Without causal mask, without Dropout. For example, BERT model and disabling attention dropout in HuggingFace.
_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Transpose", True, [(0, 0, 0)]), # 1
("Transpose", True, [(0, 0, 1)]), # 2
("Div", False, [(0, 0, 0)]), # 3
("Add", False, [(3, 0, 0)]), # 4
("Softmax", False, [(4, 0, 0)]), # 5
("MatMul", False, [(5, 0, 0)]), # 6
("Transpose", True, [(6, 0, 1)]), # 7
("Transpose", False, [(6, 0, 0)]), # 8
("FusedMatMul", False, [(7, 0, 1)]), # 9
("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10
("Identity", False, [(10, 0, 0)]), # 11
("Div", False, [(11, 0, 0)]), # 12
("Identity", False, [(12, 0, 0)]), # 13
("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14
("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15
("FusedMatMul", False, [(5, 0, 0)]), # 16
("Transpose", True, [(16, 0, 1)]), # 17
("Transpose", False, [(14, 0, 0)]), # 18
("Transpose", False, [(15, 0, 0)]), # 19
("Transpose", False, [(16, 0, 0)]), # 20
]
def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
# Check forward only as the backward is expected to be consistent if it's built correctly.
scale_value = matcher.get_constant_value(nodes[3].input[1])
if not (
check_attribute_value(nodes[1], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1])
and scale_value is not None
and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3])
):
return [], [], []
_, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0])
_, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1])
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
idx,
nodes[1].input[0],
nodes[2].input[0],
nodes[7].input[0],
nodes[8].output[0],
nodes[17].input[0],
nodes[18].output[0],
nodes[19].output[0],
nodes[20].output[0],
nodes[4].input[1],
add_input_shape_0 != add_input_shape_1,
1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
0.0,
False,
)
return nodes, nodes_to_add, new_value_infos
# No causal mask, no attention mask, without Dropout.
_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Mul", True, [(0, 0, 0)]), # 1
("Mul", True, [(0, 0, 1)]), # 2
("Cast", True, [(1, 0, 0)]), # 3
("Cast", True, [(2, 0, 0)]), # 4
("Transpose", True, [(3, 0, 0)]), # 5
("Transpose", True, [(4, 0, 0)]), # 6
("Softmax", False, [(0, 0, 0)]), # 7
("Cast", False, [(7, 0, 0)]), # 8
("MatMul", False, [(8, 0, 0)]), # 9
("Transpose", True, [(9, 0, 1)]), # 10
("Transpose", False, [(9, 0, 0)]), # 11
("FusedMatMul", False, [(10, 0, 1)]), # 12
("Cast", False, [(12, 0, 0)]), # 13
("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14
("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15
("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16
("Mul", False, [(15, 0, 0)]), # 17
("Mul", False, [(16, 0, 0)]), # 18
("Identity", False, [(17, 0, 0)]), # 19
("Identity", False, [(18, 0, 0)]), # 20
("Cast", False, [(19, 0, 0)]), # 21
("Cast", False, [(20, 0, 0)]), # 22
("Transpose", False, [(21, 0, 0)]), # 23
("Transpose", False, [(22, 0, 0)]), # 24
("FusedMatMul", False, [(8, 0, 0)]), # 25
("Transpose", True, [(25, 0, 1)]), # 26
("Transpose", False, [(25, 0, 0)]), # 27
]
def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
# Check forward only as the backward is expected to be consistent if it's built correctly.
scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
if not (
check_attribute_value(nodes[3], "to", 1)
and check_attribute_value(nodes[4], "to", 1)
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[8], "to", 10)
and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3])
and scale_value_1 == scale_value_2
):
return [], [], []
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
idx,
nodes[5].input[0],
nodes[6].input[0],
nodes[10].input[0],
nodes[11].output[0],
nodes[26].input[0],
nodes[23].output[0],
nodes[24].output[0],
nodes[27].output[0],
"",
False,
scale_value_1,
0.0,
False,
)
return nodes, nodes_to_add, new_value_infos
# Has causal mask, no attention mask, without Dropout.
_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
("MatMul", False, []), # 0
("Mul", True, [(0, 0, 0)]), # 1
("Mul", True, [(0, 0, 1)]), # 2
("Cast", True, [(1, 0, 0)]), # 3
("Cast", True, [(2, 0, 0)]), # 4
("Transpose", True, [(3, 0, 0)]), # 5
("Transpose", True, [(4, 0, 0)]), # 6
("Add", False, [(0, 0, 0)]), # 7
("Cast", True, [(7, 0, 1)]), # 8
("Slice", True, [(8, 0, 0)]), # 9
("Slice", True, [(9, 0, 0)]), # 10
("Unsqueeze", True, [(9, 0, 2)]), # 11
("Gather", True, [(11, 0, 0)]), # 12
("Shape", True, [(12, 0, 0)]), # 13
("Softmax", False, [(7, 0, 0)]), # 14
("Cast", False, [(14, 0, 0)]), # 15
("MatMul", False, [(15, 0, 0)]), # 16
("Transpose", True, [(16, 0, 1)]), # 17
("Transpose", False, [(16, 0, 0)]), # 18
("FusedMatMul", False, [(17, 0, 1)]), # 19
("Cast", False, [(19, 0, 0)]), # 20
("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21
("Identity", False, [(21, 0, 0)]), # 22
("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23
("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24
("Mul", False, [(23, 0, 0)]), # 25
("Mul", False, [(24, 0, 0)]), # 26
("Identity", False, [(25, 0, 0)]), # 27
("Identity", False, [(26, 0, 0)]), # 28
("Cast", False, [(27, 0, 0)]), # 29
("Cast", False, [(28, 0, 0)]), # 30
("Transpose", False, [(29, 0, 0)]), # 31
("Transpose", False, [(30, 0, 0)]), # 32
("FusedMatMul", False, [(15, 0, 0)]), # 33
("Transpose", True, [(33, 0, 1)]), # 34
("Transpose", False, [(33, 0, 0)]), # 35
]
def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
# Check forward only as the backward is expected to be consistent if it's built correctly.
scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
if not (
check_attribute_value(nodes[3], "to", 1)
and check_attribute_value(nodes[4], "to", 1)
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[15], "to", 10)
and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3])
and scale_value_1 == scale_value_2
):
return [], [], []
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
idx,
nodes[5].input[0],
nodes[6].input[0],
nodes[17].input[0],
nodes[18].output[0],
nodes[34].input[0],
nodes[31].output[0],
nodes[32].output[0],
nodes[35].output[0],
"",
False,
scale_value_1,
0.0,
True,
)
return nodes, nodes_to_add, new_value_infos
_PATTERNS = [
(_PATTERN_0, _optimize_for_pattern_0),
(_PATTERN_1, _optimize_for_pattern_1),
(_PATTERN_2, _optimize_for_pattern_2),
(_PATTERN_3, _optimize_for_pattern_3),
]
@register_graph_optimizer(devices="cuda")
def optimize_graph_for_aten_efficient_attention(graph: GraphProto):
nodes_to_remove = []
nodes_to_add = []
new_value_infos = []
matcher = GraphMatcher(graph)
idx = 0
for pattern_tuple in _PATTERNS:
for nodes in matcher.match_pattern(pattern_tuple[0]):
remove_nodes, add_nodes, add_value_infos = pattern_tuple[1](matcher, idx, nodes)
if len(add_nodes) > 0:
nodes_to_remove.extend(remove_nodes)
nodes_to_add.extend(add_nodes)
new_value_infos.extend(add_value_infos)
idx += 1
update_graph(graph, nodes_to_remove, nodes_to_add, new_value_infos)

View file

@ -0,0 +1,178 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import itertools
from typing import Any, Dict, List, Sequence, Tuple
import numpy as np
from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper
def _get_attribute(node: NodeProto, attr_name: str, default_value: Any = None) -> Any:
"""Get attribute value from node by attribute key."""
found = [attr for attr in node.attribute if attr.name == attr_name]
if found:
return helper.get_attribute_value(found[0])
return default_value
def _to_numpy_array(node: Any) -> np.ndarray:
"""Convert Constant node or TensorProto to Python value."""
tensor = node
if isinstance(node, NodeProto):
tensor = _get_attribute(node, "value")
assert isinstance(tensor, TensorProto)
return numpy_helper.to_array(tensor).tolist()
class GraphMatcher:
"""Sub-graph matcher with given pattern.
GraphMatcher takes an ONNX graph to initialize. It tries to match sub-graphs to a given pattern and yield
matched sub-graphs (a list of matched nodes for each sub-graph) one by one.
Pattern is described by a list. Each entry of the list is a Tuple:
Tuple[str, bool, List[Tuple[int, int, int]]], e.g., ("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)])
* First string is the Op type, e.g., "FusedMatMul".
* Second bool indicates it's producer node or consumer node for source node.
* There is a list to describe the edge infos of this node to other nodes, each edge is a tuple with 3 integers,
first integer is the index of the target node in the list, second integer is the output index of the edge,
and thrid integer is the input index of the edge.
For each entry, GraphMatcher used the first edge to lookup target node, and try to use make sure the sug-graph also
matches rest edge infos.
Note that when lookup target node, it will only take the first matched node as target node. For example, if a source
node has multiple "MatMul" consumers nodes comsuming same output, only the first "MatMul" node will be returned.
You need to avoid using such confusing edge info as the first edge info for node lookup. Try to use other edge to
avoid such confusion if possible.
"""
def __init__(self, graph: GraphProto):
self._graph: GraphProto = graph
self._op_type_to_nodes: Dict[str, List[NodeProto]] = {}
self._consumer_count: Dict[str, int] = {}
for node in graph.node:
if node.op_type not in self._op_type_to_nodes:
self._op_type_to_nodes[node.op_type] = []
self._op_type_to_nodes[node.op_type].append(node)
for input in node.input:
self._consumer_count[input] = self._consumer_count.get(input, 0) + 1
def _get_producer(self, arg: str, op_type: str, output_idx: int):
for node in self._op_type_to_nodes.get(op_type, []):
if (output_idx >= 0 and len(node.output) > output_idx and node.output[output_idx] == arg) or (
output_idx == -1 and arg in node.output
):
return node
return None
def _get_consumer(self, arg: str, op_type: str, input_idx: int):
for node in self._op_type_to_nodes.get(op_type, []):
if (input_idx >= 0 and len(node.input) > input_idx and node.input[input_idx] == arg) or (
input_idx == -1 and arg in node.input
):
return node
return None
def get_consumer_count(self, arg: str):
return self._consumer_count.get(arg, 0)
def get_constant_value(self, arg: str):
node_or_initializer = None
if "Constant" in self._op_type_to_nodes:
for node in self._op_type_to_nodes["Constant"]:
if arg in node.output:
node_or_initializer = node
break
if node_or_initializer is None:
for initializer in self._graph.initializer:
if arg == initializer.name:
node_or_initializer = initializer
break
if node_or_initializer is None:
return None
return _to_numpy_array(node_or_initializer)
def get_type_and_shape(self, arg: str):
value_infos = [
value_info
for value_info in itertools.chain(self._graph.input, self._graph.value_info)
if value_info.name == arg
]
if len(value_infos) > 0 and value_infos[0].type.tensor_type.HasField("shape"):
shape = []
for dim in value_infos[0].type.tensor_type.shape.dim:
if dim.dim_param:
shape.append(dim.dim_param)
else:
shape.append(dim.dim_value)
return value_infos[0].type.tensor_type.elem_type, shape
initializers = [initializer for initializer in self._graph.initializer if initializer.name == arg]
if len(initializers) > 0:
return initializers[0].data_type, initializers[0].dims
return None, None
def _match_pattern(self, node: NodeProto, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]):
nodes = [node]
for i in range(1, len(pattern)):
next_op_type = pattern[i][0]
is_producer = pattern[i][1]
node_idx, output_idx, input_idx = pattern[i][2][0]
next_node = (
self._get_producer(nodes[node_idx].input[input_idx], next_op_type, output_idx)
if is_producer
else self._get_consumer(nodes[node_idx].output[output_idx], next_op_type, input_idx)
)
if next_node is None:
return []
for j in range(1, len(pattern[i][2])):
node_idx, output_idx, input_idx = pattern[i][2][j]
assert output_idx >= 0 and input_idx >= 0
if (not is_producer and nodes[node_idx].output[output_idx] != next_node.input[input_idx]) or (
is_producer and next_node.output[output_idx] != nodes[node_idx].input[input_idx]
):
return []
nodes.append(next_node)
return nodes
def match_pattern(self, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]):
for node in self._op_type_to_nodes.get(pattern[0][0], []):
result = self._match_pattern(node, pattern)
if len(result) == len(pattern):
yield result
def check_attribute_value(node: NodeProto, attr_name: str, expected_value: Any):
"""Check if the attribute of given node has expected value."""
value = _get_attribute(node, attr_name)
return value == expected_value
def make_constant_node(name: str, dtype: TensorProto.DataType, dims: Sequence[int], vals: Any):
"""Create a constant node with given constant tensor (data type, shape, and data)."""
return helper.make_node(
"Constant",
inputs=[],
outputs=[name],
value=helper.make_tensor(name=name, data_type=dtype, dims=dims, vals=vals),
)
def update_graph(
graph: GraphProto,
nodes_to_remove: List[NodeProto],
nodes_to_add: List[NodeProto],
new_value_infos: List[TensorProto] = [], # noqa: B006
):
"""Update an ONNX graph by removing some nodes, and adding some new nodes and value infos."""
nodes = [node for node in graph.node if node not in nodes_to_remove]
nodes.extend(nodes_to_add)
graph.ClearField("node")
graph.node.extend(nodes)
if len(new_value_infos) > 0:
graph.value_info.extend(new_value_infos)
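To make the pattern format concrete, here is a hedged sketch of running GraphMatcher over a toy graph; the two-node pattern, the toy graph, and the assumption that these helpers are importable as graph_optimizers.utils are illustrative only.

```python
from onnx import TensorProto, helper

from onnxruntime.training.ortmodule.graph_optimizers.utils import GraphMatcher  # assumed module path

# Toy graph: MatMul -> Softmax, just enough structure to exercise the matcher.
matmul = helper.make_node("MatMul", ["q", "k"], ["scores"])
softmax = helper.make_node("Softmax", ["scores"], ["probs"])
graph = helper.make_graph(
    [matmul, softmax],
    "toy",
    [helper.make_tensor_value_info(name, TensorProto.FLOAT, [2, 2]) for name in ("q", "k")],
    [helper.make_tensor_value_info("probs", TensorProto.FLOAT, [2, 2])],
)

# Pattern: start from a MatMul (entry 0), then find a Softmax consuming
# output 0 of entry 0 at its input index 0.
pattern = [
    ("MatMul", False, []),
    ("Softmax", False, [(0, 0, 0)]),
]

matcher = GraphMatcher(graph)
for nodes in matcher.match_pattern(pattern):
    print([node.op_type for node in nodes])  # expected: ['MatMul', 'Softmax']
```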

View file

@ -1,47 +0,0 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from typing import Callable
from onnx.onnx_ml_pb2 import GraphProto
class GraphTransformerRegistry:
_TRANSFORMER_FUNCS = {} # noqa: RUF012
@classmethod
def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]):
modules = []
if target_modules == "all":
modules.append("all")
else:
modules = target_modules.split("|")
for module in modules:
if module in cls._TRANSFORMER_FUNCS:
cls._TRANSFORMER_FUNCS[module].append((fn, devices, priority))
else:
cls._TRANSFORMER_FUNCS[module] = [(fn, devices, priority)]
@classmethod
def transform_all(cls, module_name: str, device: str, graph: GraphProto):
transformers_to_apply = []
if "all" in cls._TRANSFORMER_FUNCS:
transformers_to_apply.extend(cls._TRANSFORMER_FUNCS["all"])
if module_name in cls._TRANSFORMER_FUNCS:
transformers_to_apply.extend(cls._TRANSFORMER_FUNCS[module_name])
transformers_to_apply = [x for x in transformers_to_apply if x[1] == "all" or device in x[1]]
transformers_to_apply.sort(key=lambda x: x[2], reverse=True)
for fn, _, _ in transformers_to_apply:
fn(graph)
# target_modules can be multiple module names separated by "|", or "all" means apply to all modules.
# devices can be multiple device types separated by "|" or "all" means apply to all devices.
def register_graph_transformer(target_modules: str = "all", devices: str = "all", priority: int = 0):
def graph_transformer_wrapper(fn):
GraphTransformerRegistry.register(target_modules, devices, priority, fn)
return fn
return graph_transformer_wrapper

View file

@ -17,8 +17,8 @@ InlinedHashSet<size_t> TritonOp::GetBoolOutputs(size_t output_size) const {
InlinedHashSet<size_t> bool_outputs;
for (size_t i = 0; i < output_size; ++i) {
ORT_ENFORCE(i < Node().OutputDefs().size(), "Output index out of range.");
if (Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() ==
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) {
if (Node().OutputDefs()[i]->Exists() && Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() ==
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) {
bool_outputs.insert(i);
}
}
@ -37,13 +37,15 @@ Status TritonOp::Compute(OpKernelContext* context) const {
InlinedHashSet<size_t> bool_outputs = GetBoolOutputs(output_size);
auto& executor = training::framework::triton::TritonOpExecutor::Instance();
if (func_name_ != "") {
executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs);
executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs, kwargs_);
} else {
executor.ExecuteByOnnx(onnx_key_, onnx_string_, inputs, outputs, bool_outputs);
}
ORT_ENFORCE(output_size == outputs.size());
for (size_t i = 0; i < output_size; ++i) {
ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast<int>(i), outputs[i]));
if (Node().OutputDefs()[i]->Exists()) {
ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast<int>(i), outputs[i]));
}
}
return Status::OK();
}

View file

@ -5,6 +5,8 @@
#pragma once
#include "core/common/inlined_containers.h"
#ifndef SHARED_PROVIDER
#include "core/framework/op_kernel.h"
#endif
@ -18,6 +20,19 @@ class TritonOp final : public OpKernel {
ORT_THROW_IF_ERROR(info.GetAttr("func_name", &func_name_));
ORT_THROW_IF_ERROR(info.GetAttr("onnx_key", &onnx_key_));
ORT_THROW_IF_ERROR(info.GetAttr("onnx_string", &onnx_string_));
for (const auto& attr : info.node().GetAttributes()) {
if (attr.first.rfind("_", 0) == 0 || attr.first == "func_name" || attr.first == "onnx_key" ||
attr.first == "onnx_string") {
continue;
}
// Support int64 and float only for now, skip other types.
if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) {
kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}});
} else if (attr.second.type() ==
ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) {
kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}});
}
}
}
Status Compute(OpKernelContext* context) const override;
@ -28,6 +43,7 @@ class TritonOp final : public OpKernel {
std::string func_name_;
int64_t onnx_key_;
std::string onnx_string_;
InlinedHashMap<std::string, std::pair<std::string, int>> kwargs_;
};
bool IsTritonOpExecutorInitialized();

View file

@ -92,3 +92,4 @@ unfixable = [
"tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix
"onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix
"onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix
"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo.

View file

@ -466,6 +466,7 @@ if enable_training or enable_training_apis:
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
"onnxruntime.training.ortmodule.graph_optimizers",
"onnxruntime.training.ort_triton",
"onnxruntime.training.ort_triton.kernel",
"onnxruntime.training.utils",