[ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959)
This PR is to support efficient attention and flash attention in ORTModule, including: - Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev or newer. ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable. - Integrate Triton Flash attention, which requires triton==2.0.0.dev20221202. Need A100 or H100. ORTMODULE_USE_FLASH_ATTENTION=1 to enable. - A python transformer tool to match sub-graph by config and write transformer quickly. Current transformers supports attention mask for both efficient attn and flash attn, and dropout for efficient attn only. To support more training scenarios (such as causal mask in GPT2), more transformers need to be added. The feature is guarded by system environment variables, it won't effect any current behavior if not enabled. Since it requires specific PyTorch/Triton versions, related tests is not added for now.
This commit is contained in:
Родитель
37873be86d
Коммит
b7408f7389
|
@ -387,6 +387,9 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*"
|
||||
)
|
||||
file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
|
||||
)
|
||||
file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
|
||||
)
|
||||
|
@ -741,6 +744,7 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
|
||||
|
@ -794,6 +798,9 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ortmodule_graph_optimizers_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${onnxruntime_python_ort_triton_srcs}
|
||||
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/
|
||||
|
|
|
@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const {
|
|||
aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size,
|
||||
dlpack_outputs.get());
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(dlpack_outputs[i])));
|
||||
if (dlpack_outputs[i]) {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(dlpack_outputs[i])));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -10,7 +10,7 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace aten_ops {
|
||||
|
||||
typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index);
|
||||
typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input);
|
||||
typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size,
|
||||
DLManagedTensor** dlpack_inputs, size_t output_size,
|
||||
DLManagedTensor** dlpack_outputs);
|
||||
|
@ -22,17 +22,17 @@ class ATenOperatorExecutor {
|
|||
return instance;
|
||||
}
|
||||
|
||||
void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
|
||||
ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw);
|
||||
p_is_tensor_argument_func_ = reinterpret_cast<IsTensorArgumentFunc>(p_is_tensor_argument_func_raw);
|
||||
void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) {
|
||||
ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw);
|
||||
p_is_cpu_argument_func_ = reinterpret_cast<IsCpuArgumentFunc>(p_is_cpu_argument_func_raw);
|
||||
p_execute_aten_op_func_ = reinterpret_cast<ExecuteATenOperatorFunc>(p_execute_aten_op_func_raw);
|
||||
}
|
||||
|
||||
bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; }
|
||||
|
||||
bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) {
|
||||
ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized.");
|
||||
return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index);
|
||||
bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) {
|
||||
ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized.");
|
||||
return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input);
|
||||
}
|
||||
|
||||
void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size,
|
||||
|
@ -43,7 +43,7 @@ class ATenOperatorExecutor {
|
|||
}
|
||||
|
||||
private:
|
||||
IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr;
|
||||
IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr;
|
||||
ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr;
|
||||
};
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "onnx/defs/data_type_utils.h"
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE::Utils;
|
||||
|
||||
|
@ -77,7 +78,7 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
|
|||
ORT_THROW_IF_ERROR(node->ForEachWithIndex(
|
||||
node->OutputDefs(),
|
||||
[&](const NodeArg& node_arg, size_t out_index) {
|
||||
if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) {
|
||||
if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) {
|
||||
cpu_output_args.insert(&node_arg);
|
||||
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
|
||||
for (auto& consumer_node : consumer_nodes) {
|
||||
|
|
|
@ -1025,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index)
|
|||
overload_name = attrs.at("overload_name").s();
|
||||
}
|
||||
|
||||
return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index);
|
||||
return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true);
|
||||
}
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(node);
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) {
|
||||
if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ATEN
|
||||
if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
|
||||
node.Domain() == kPytorchAtenDomain) {
|
||||
const auto& attrs = node.GetAttributes();
|
||||
ORT_ENFORCE(utils::HasString(attrs.at("operator")));
|
||||
std::string op_name = attrs.at("operator").s();
|
||||
std::string overload_name = "";
|
||||
if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) {
|
||||
overload_name = attrs.at("overload_name").s();
|
||||
}
|
||||
|
||||
return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false);
|
||||
}
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(node);
|
||||
|
|
|
@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet
|
|||
bool sync_subgraph_fetches = false);
|
||||
|
||||
bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
|
||||
bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
|
||||
|
||||
template <typename T>
|
||||
constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {
|
||||
|
|
|
@ -249,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
|
|||
if (!arg->Exists())
|
||||
continue;
|
||||
|
||||
if (kci && kci->kernel_def->IsOutputOnCpu(i))
|
||||
if (utils::IsOutputOnCpu(node, kci, i))
|
||||
non_provider_output_defs_.insert(arg);
|
||||
else
|
||||
provider_output_defs_.insert(arg);
|
||||
|
@ -308,7 +308,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
|
|||
if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it);
|
||||
}
|
||||
if (arg_output_index != -1) {
|
||||
if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it);
|
||||
if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -404,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
|
|||
// normally initializers are only inputs, but things may change with ops like assign
|
||||
ORT_THROW_IF_ERROR(Node::ForEachWithIndex(
|
||||
p_node->OutputDefs(),
|
||||
[kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) {
|
||||
if (kci->kernel_def->IsOutputOnCpu(index)) {
|
||||
[kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) {
|
||||
if (utils::IsOutputOnCpu(*p_node, kci, index)) {
|
||||
ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end());
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -1214,14 +1214,14 @@ void addGlobalMethods(py::module& m) {
|
|||
|
||||
#ifdef ENABLE_ATEN
|
||||
m.def("register_aten_op_executor",
|
||||
[](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
|
||||
size_t is_tensor_argument_address_int, aten_op_executor_address_int;
|
||||
[](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
|
||||
size_t is_cpu_argument_address_int, aten_op_executor_address_int;
|
||||
ORT_THROW_IF_ERROR(
|
||||
ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int));
|
||||
ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int));
|
||||
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
|
||||
void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
|
||||
void* p_is_cpu_argument = reinterpret_cast<void*>(is_cpu_argument_address_int);
|
||||
void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
|
||||
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
|
||||
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension():
|
|||
from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor
|
||||
|
||||
_C.register_aten_op_executor(
|
||||
str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address())
|
||||
str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address())
|
||||
)
|
||||
|
|
|
@ -154,11 +154,32 @@ class ATenOperatorCache {
|
|||
std::unordered_map<std::pair<std::string, std::string>, ATenOperator, PairHash> ops_;
|
||||
};
|
||||
|
||||
// Backend uses this function to check if an argument is CPU input (non-tensor argument) or not.
|
||||
bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index) {
|
||||
const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
|
||||
TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
|
||||
return aten_op.elem_kinds[index] == c10::TypeKind::TensorType;
|
||||
const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorInputsMap = {
|
||||
{"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}};
|
||||
|
||||
const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorOutputsMap = {
|
||||
{"_efficient_attention_forward", {2, 3}}};
|
||||
|
||||
// Backend uses this function to check if an argument is CPU input or not.
|
||||
bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) {
|
||||
if (is_input) {
|
||||
// If the argument is non-tensor type, it's CPU argument.
|
||||
const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
|
||||
TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
|
||||
if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
std::string full_name = std::string(op_name);
|
||||
std::string overload_name_str = std::string(overload_name);
|
||||
if (overload_name_str != "") {
|
||||
full_name += ("." + overload_name_str);
|
||||
}
|
||||
|
||||
const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap;
|
||||
return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() &&
|
||||
cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end();
|
||||
}
|
||||
|
||||
void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size,
|
||||
|
@ -196,14 +217,15 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t
|
|||
size_t output_index = 0;
|
||||
for (const auto& ret : torch::jit::pop(stack, output_size)) {
|
||||
const auto& tensor = ret.toTensor();
|
||||
dlpack_outputs[output_index++] = at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous());
|
||||
dlpack_outputs[output_index++] =
|
||||
tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
size_t is_tensor_argument_address() { return reinterpret_cast<size_t>(&IsTensorArgument); }
|
||||
size_t is_cpu_argument_address() { return reinterpret_cast<size_t>(&IsCpuArgument); }
|
||||
size_t execute_aten_operator_address() { return reinterpret_cast<size_t>(&ExecuteATenOperator); }
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check.");
|
||||
m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check.");
|
||||
m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor");
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch # noqa: F401
|
|||
|
||||
from onnxruntime.capi import _pybind_state as _C
|
||||
|
||||
from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address
|
||||
from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address
|
||||
|
||||
|
||||
def run_once_aten_op_executor(f):
|
||||
|
@ -30,7 +30,7 @@ def run_once_aten_op_executor(f):
|
|||
|
||||
@run_once_aten_op_executor
|
||||
def load_aten_op_executor_cpp_extension():
|
||||
_C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address()))
|
||||
_C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address()))
|
||||
|
||||
|
||||
def init_aten_op_executor():
|
||||
|
|
|
@ -4180,6 +4180,7 @@ Return true if all elements are true and false otherwise.
|
|||
.Attr("func_name", "Function name of the Python Triton kernel.", AttributeProto::STRING, std::string(""))
|
||||
.Attr("onnx_key", "The hash key for the ONNX graph.", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("onnx_string", "The onnx string of the triton kernel.", AttributeProto::STRING, std::string(""))
|
||||
.AllowUncheckedAttributes()
|
||||
.Input(0, "inputs",
|
||||
"Input tensors. If to call an existing Python Triton kernel, "
|
||||
"the input count and order should match the arguments of the function. If to compute an ONNX graph, "
|
||||
|
|
|
@ -3,15 +3,28 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out
|
||||
from ._slice_scel import slice_scel, slice_scel_backward, transform_slice_scel
|
||||
import os
|
||||
|
||||
__all__ = [
|
||||
from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401
|
||||
from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward # noqa: F401
|
||||
|
||||
_all_kernels = [
|
||||
"triton_gemm",
|
||||
"triton_gemm_out",
|
||||
"triton_matmul",
|
||||
"triton_matmul_out",
|
||||
"slice_scel",
|
||||
"slice_scel_backward",
|
||||
"transform_slice_scel",
|
||||
]
|
||||
|
||||
_all_optimizers = [
|
||||
"optimize_graph_for_slice_scel",
|
||||
]
|
||||
|
||||
if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1:
|
||||
from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention # noqa: F401
|
||||
|
||||
_all_kernels.extend(["flash_attn_forward", "flash_attn_backward"])
|
||||
_all_optimizers.append("optimize_graph_for_flash_attention")
|
||||
|
||||
__all__ = _all_kernels + _all_optimizers # noqa: PLE0605
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -11,7 +11,7 @@ import triton
|
|||
import triton.language as tl
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
from onnxruntime.training.ortmodule import register_graph_transformer
|
||||
from onnxruntime.training.ortmodule import register_graph_optimizer
|
||||
|
||||
from .._utils import get_attribute, to_numpy_array
|
||||
|
||||
|
@ -246,8 +246,8 @@ def _get_shape_related_nodes(graph, start_arg, sub_graph_nodes):
|
|||
args.append(output)
|
||||
|
||||
|
||||
@register_graph_transformer(devices="cuda")
|
||||
def transform_slice_scel(graph):
|
||||
@register_graph_optimizer(devices="cuda")
|
||||
def optimize_graph_for_slice_scel(graph):
|
||||
remove_nodes = []
|
||||
triton_nodes = []
|
||||
value_infos = []
|
||||
|
|
|
@ -124,7 +124,8 @@ def _are_deterministic_algorithms_enabled():
|
|||
return ORTMODULE_IS_DETERMINISTIC
|
||||
|
||||
|
||||
from .graph_transformer_registry import register_graph_transformer # noqa: E402, F401
|
||||
from .graph_optimizer_registry import register_graph_optimizer # noqa: E402, F401
|
||||
from .graph_optimizers import * # noqa: E402, F403
|
||||
from .options import DebugOptions, LogLevel # noqa: E402, F401
|
||||
|
||||
# ORTModule must be loaded only after all validation passes
|
||||
|
|
|
@ -21,7 +21,7 @@ from ._io import _FlattenedModule, _InputInfo, unflatten_user_output
|
|||
from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime
|
||||
from ._runtime_inspector import Phase
|
||||
from ._utils import save_tuning_results, set_tuning_results
|
||||
from .graph_transformer_registry import GraphTransformerRegistry
|
||||
from .graph_optimizer_registry import GraphOptimizerRegistry
|
||||
from .options import DebugOptions, _SkipCheck
|
||||
|
||||
|
||||
|
@ -369,7 +369,7 @@ class TrainingManager(GraphExecutionManager):
|
|||
device_type = self._device.type
|
||||
if device_type == "cuda" and self.is_rocm_pytorch:
|
||||
device_type = "rocm"
|
||||
GraphTransformerRegistry.transform_all(
|
||||
GraphOptimizerRegistry.optimize_all(
|
||||
type(self._flattened_module._original_module).__name__, device_type, self._onnx_models.optimized_model.graph
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from onnx.onnx_ml_pb2 import GraphProto
|
||||
|
||||
|
||||
class GraphOptimizerRegistry:
|
||||
_OPTIMIZER_FUNCS = {} # noqa: RUF012
|
||||
|
||||
@classmethod
|
||||
def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]):
|
||||
modules = []
|
||||
if target_modules == "all":
|
||||
modules.append("all")
|
||||
else:
|
||||
modules = target_modules.split("|")
|
||||
for module in modules:
|
||||
if module in cls._OPTIMIZER_FUNCS:
|
||||
cls._OPTIMIZER_FUNCS[module].append((fn, devices, priority))
|
||||
else:
|
||||
cls._OPTIMIZER_FUNCS[module] = [(fn, devices, priority)]
|
||||
|
||||
@classmethod
|
||||
def optimize_all(cls, module_name: str, device: str, graph: GraphProto):
|
||||
optimizers_to_apply = []
|
||||
if "all" in cls._OPTIMIZER_FUNCS:
|
||||
optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS["all"])
|
||||
if module_name in cls._OPTIMIZER_FUNCS:
|
||||
optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS[module_name])
|
||||
optimizers_to_apply = [x for x in optimizers_to_apply if x[1] == "all" or device in x[1]]
|
||||
optimizers_to_apply.sort(key=lambda x: x[2], reverse=True)
|
||||
for fn, _, _ in optimizers_to_apply:
|
||||
fn(graph)
|
||||
|
||||
|
||||
# target_modules can be multiple module names separated by "|", or "all" means apply to all modules.
|
||||
# devices can be multiple device types separated by "|" or "all" means apply to all devices.
|
||||
def register_graph_optimizer(target_modules: str = "all", devices: str = "all", priority: int = 0):
|
||||
def graph_optimizer_wrapper(fn):
|
||||
GraphOptimizerRegistry.register(target_modules, devices, priority, fn)
|
||||
return fn
|
||||
|
||||
return graph_optimizer_wrapper
|
|
@ -0,0 +1,15 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
|
||||
_all_optimizers = []
|
||||
|
||||
if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1:
|
||||
from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401
|
||||
|
||||
_all_optimizers.append("optimize_graph_for_aten_efficient_attention")
|
||||
|
||||
__all__ = _all_optimizers # noqa: PLE0605
|
|
@ -0,0 +1,414 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs is keep changing. Current implementation
|
||||
is tested well on version 2.2.0.dev20231010+cu121, and should be run well since official version 2.2.0. If may fail to
|
||||
run is you are using PyTorch with older versions.
|
||||
|
||||
PyTorch also has API for flash attention (currently doesn't support random attention mask or Dropout), we can add
|
||||
support if we want to try in the future.
|
||||
"""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
from onnx import GraphProto, NodeProto, TensorProto, helper
|
||||
|
||||
from ..graph_optimizer_registry import register_graph_optimizer
|
||||
from .utils import GraphMatcher, check_attribute_value, make_constant_node, update_graph
|
||||
|
||||
|
||||
def _make_efficient_attention_nodes(
|
||||
idx: int,
|
||||
q: str,
|
||||
k: str,
|
||||
v: str,
|
||||
y: str,
|
||||
dy: str,
|
||||
dq: str,
|
||||
dk: str,
|
||||
dv: str,
|
||||
bias: str,
|
||||
expand_bias: bool,
|
||||
scale: float,
|
||||
dropout_ratio: float,
|
||||
causal: bool,
|
||||
):
|
||||
nodes_to_add = []
|
||||
scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale])
|
||||
dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio])
|
||||
causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0])
|
||||
int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0])
|
||||
true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True])
|
||||
false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False])
|
||||
logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, [])
|
||||
seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, [])
|
||||
offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, [])
|
||||
new_value_infos = [logsumexp, seed, offset]
|
||||
if expand_bias:
|
||||
shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1)
|
||||
shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3)
|
||||
shape_2 = helper.make_node("Shape", [q], ["shape_2_" + str(idx)], start=1, end=2)
|
||||
shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2)
|
||||
concat = helper.make_node(
|
||||
"Concat",
|
||||
["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)],
|
||||
["concated_shape_" + str(idx)],
|
||||
axis=0,
|
||||
)
|
||||
expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)])
|
||||
nodes_to_add.extend([shape_0, shape_1, shape_2, shape_3, concat, expand])
|
||||
bias = "expanded_bias_" + str(idx)
|
||||
fwd_node = helper.make_node(
|
||||
"ATen",
|
||||
[
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
dropout_ratio_node.output[0],
|
||||
causal_node.output[0],
|
||||
true_node.output[0],
|
||||
scale_node.output[0],
|
||||
"",
|
||||
"",
|
||||
],
|
||||
[y, logsumexp.name, seed.name, offset.name],
|
||||
"efficient_attention_forward_" + str(idx),
|
||||
None,
|
||||
"org.pytorch.aten",
|
||||
operator="_efficient_attention_forward",
|
||||
)
|
||||
bwd_node = helper.make_node(
|
||||
"ATen",
|
||||
[
|
||||
dy,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
y,
|
||||
"",
|
||||
"",
|
||||
int_zero_node.output[0],
|
||||
int_zero_node.output[0],
|
||||
logsumexp.name,
|
||||
dropout_ratio_node.output[0],
|
||||
seed.name,
|
||||
offset.name,
|
||||
causal_node.output[0],
|
||||
false_node.output[0],
|
||||
scale_node.output[0],
|
||||
"",
|
||||
],
|
||||
[dq, dk, dv, ""],
|
||||
"efficient_attention_backward_" + str(idx),
|
||||
None,
|
||||
"org.pytorch.aten",
|
||||
operator="_efficient_attention_backward",
|
||||
)
|
||||
nodes_to_add.extend(
|
||||
[scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node]
|
||||
)
|
||||
return nodes_to_add, new_value_infos
|
||||
|
||||
|
||||
# Without causal mask, with Dropout. For example, BERT model in HuggingFace.
|
||||
_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
|
||||
("MatMul", False, []), # 0
|
||||
("Transpose", True, [(0, 0, 0)]), # 1
|
||||
("Transpose", True, [(0, 0, 1)]), # 2
|
||||
("Div", False, [(0, 0, 0)]), # 3
|
||||
("Add", False, [(3, 0, 0)]), # 4
|
||||
("Softmax", False, [(4, 0, 0)]), # 5
|
||||
("Dropout", False, [(5, 0, 0)]), # 6
|
||||
("MatMul", False, [(6, 0, 0)]), # 7
|
||||
("Transpose", True, [(7, 0, 1)]), # 8
|
||||
("Transpose", False, [(7, 0, 0)]), # 9
|
||||
("FusedMatMul", False, [(8, 0, 1)]), # 10
|
||||
("DropoutGrad", False, [(10, 0, 0), (6, 1, 1)]), # 11
|
||||
("SoftmaxGrad_13", False, [(11, 0, 0), (5, 0, 1)]), # 12
|
||||
("Identity", False, [(12, 0, 0)]), # 13
|
||||
("Div", False, [(13, 0, 0)]), # 14
|
||||
("Identity", False, [(14, 0, 0)]), # 15
|
||||
("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)]), # 16
|
||||
("FusedMatMul", False, [(1, 0, 0), (15, 0, 1)]), # 17
|
||||
("FusedMatMul", False, [(6, 0, 0)]), # 18
|
||||
("Transpose", True, [(18, 0, 1)]), # 19
|
||||
("Transpose", False, [(16, 0, 0)]), # 20
|
||||
("Transpose", False, [(17, 0, 0)]), # 21
|
||||
("Transpose", False, [(18, 0, 0)]), # 22
|
||||
]
|
||||
|
||||
|
||||
def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
|
||||
# Check forward only as the backward is expected to be consistent if it's built correctly.
|
||||
scale_value = matcher.get_constant_value(nodes[3].input[1])
|
||||
ratio_value = matcher.get_constant_value(nodes[6].input[1])
|
||||
if not (
|
||||
check_attribute_value(nodes[1], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1])
|
||||
and scale_value is not None
|
||||
and ratio_value is not None
|
||||
and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3])
|
||||
):
|
||||
return [], [], []
|
||||
|
||||
_, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0])
|
||||
_, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1])
|
||||
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
|
||||
idx,
|
||||
nodes[1].input[0],
|
||||
nodes[2].input[0],
|
||||
nodes[8].input[0],
|
||||
nodes[9].output[0],
|
||||
nodes[19].input[0],
|
||||
nodes[20].output[0],
|
||||
nodes[21].output[0],
|
||||
nodes[22].output[0],
|
||||
nodes[4].input[1],
|
||||
add_input_shape_0 != add_input_shape_1,
|
||||
1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
|
||||
ratio_value,
|
||||
False,
|
||||
)
|
||||
return nodes, nodes_to_add, new_value_infos
|
||||
|
||||
|
||||
# Without causal mask, without Dropout. For example, BERT model and disabling attention dropout in HuggingFace.
|
||||
_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
|
||||
("MatMul", False, []), # 0
|
||||
("Transpose", True, [(0, 0, 0)]), # 1
|
||||
("Transpose", True, [(0, 0, 1)]), # 2
|
||||
("Div", False, [(0, 0, 0)]), # 3
|
||||
("Add", False, [(3, 0, 0)]), # 4
|
||||
("Softmax", False, [(4, 0, 0)]), # 5
|
||||
("MatMul", False, [(5, 0, 0)]), # 6
|
||||
("Transpose", True, [(6, 0, 1)]), # 7
|
||||
("Transpose", False, [(6, 0, 0)]), # 8
|
||||
("FusedMatMul", False, [(7, 0, 1)]), # 9
|
||||
("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10
|
||||
("Identity", False, [(10, 0, 0)]), # 11
|
||||
("Div", False, [(11, 0, 0)]), # 12
|
||||
("Identity", False, [(12, 0, 0)]), # 13
|
||||
("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14
|
||||
("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15
|
||||
("FusedMatMul", False, [(5, 0, 0)]), # 16
|
||||
("Transpose", True, [(16, 0, 1)]), # 17
|
||||
("Transpose", False, [(14, 0, 0)]), # 18
|
||||
("Transpose", False, [(15, 0, 0)]), # 19
|
||||
("Transpose", False, [(16, 0, 0)]), # 20
|
||||
]
|
||||
|
||||
|
||||
def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
|
||||
# Check forward only as the backward is expected to be consistent if it's built correctly.
|
||||
scale_value = matcher.get_constant_value(nodes[3].input[1])
|
||||
if not (
|
||||
check_attribute_value(nodes[1], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1])
|
||||
and scale_value is not None
|
||||
and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3])
|
||||
):
|
||||
return [], [], []
|
||||
|
||||
_, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0])
|
||||
_, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1])
|
||||
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
|
||||
idx,
|
||||
nodes[1].input[0],
|
||||
nodes[2].input[0],
|
||||
nodes[7].input[0],
|
||||
nodes[8].output[0],
|
||||
nodes[17].input[0],
|
||||
nodes[18].output[0],
|
||||
nodes[19].output[0],
|
||||
nodes[20].output[0],
|
||||
nodes[4].input[1],
|
||||
add_input_shape_0 != add_input_shape_1,
|
||||
1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
|
||||
0.0,
|
||||
False,
|
||||
)
|
||||
return nodes, nodes_to_add, new_value_infos
|
||||
|
||||
|
||||
# No causal mask, no attention mask, without Dropout.
|
||||
_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
|
||||
("MatMul", False, []), # 0
|
||||
("Mul", True, [(0, 0, 0)]), # 1
|
||||
("Mul", True, [(0, 0, 1)]), # 2
|
||||
("Cast", True, [(1, 0, 0)]), # 3
|
||||
("Cast", True, [(2, 0, 0)]), # 4
|
||||
("Transpose", True, [(3, 0, 0)]), # 5
|
||||
("Transpose", True, [(4, 0, 0)]), # 6
|
||||
("Softmax", False, [(0, 0, 0)]), # 7
|
||||
("Cast", False, [(7, 0, 0)]), # 8
|
||||
("MatMul", False, [(8, 0, 0)]), # 9
|
||||
("Transpose", True, [(9, 0, 1)]), # 10
|
||||
("Transpose", False, [(9, 0, 0)]), # 11
|
||||
("FusedMatMul", False, [(10, 0, 1)]), # 12
|
||||
("Cast", False, [(12, 0, 0)]), # 13
|
||||
("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14
|
||||
("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15
|
||||
("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16
|
||||
("Mul", False, [(15, 0, 0)]), # 17
|
||||
("Mul", False, [(16, 0, 0)]), # 18
|
||||
("Identity", False, [(17, 0, 0)]), # 19
|
||||
("Identity", False, [(18, 0, 0)]), # 20
|
||||
("Cast", False, [(19, 0, 0)]), # 21
|
||||
("Cast", False, [(20, 0, 0)]), # 22
|
||||
("Transpose", False, [(21, 0, 0)]), # 23
|
||||
("Transpose", False, [(22, 0, 0)]), # 24
|
||||
("FusedMatMul", False, [(8, 0, 0)]), # 25
|
||||
("Transpose", True, [(25, 0, 1)]), # 26
|
||||
("Transpose", False, [(25, 0, 0)]), # 27
|
||||
]
|
||||
|
||||
|
||||
def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
|
||||
# Check forward only as the backward is expected to be consistent if it's built correctly.
|
||||
scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
|
||||
scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
|
||||
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
|
||||
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
|
||||
if not (
|
||||
check_attribute_value(nodes[3], "to", 1)
|
||||
and check_attribute_value(nodes[4], "to", 1)
|
||||
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
|
||||
and check_attribute_value(nodes[8], "to", 10)
|
||||
and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3])
|
||||
and scale_value_1 == scale_value_2
|
||||
):
|
||||
return [], [], []
|
||||
|
||||
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
|
||||
idx,
|
||||
nodes[5].input[0],
|
||||
nodes[6].input[0],
|
||||
nodes[10].input[0],
|
||||
nodes[11].output[0],
|
||||
nodes[26].input[0],
|
||||
nodes[23].output[0],
|
||||
nodes[24].output[0],
|
||||
nodes[27].output[0],
|
||||
"",
|
||||
False,
|
||||
scale_value_1,
|
||||
0.0,
|
||||
False,
|
||||
)
|
||||
return nodes, nodes_to_add, new_value_infos
|
||||
|
||||
|
||||
# Has causal mask, no attention mask, without Dropout.
|
||||
_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
|
||||
("MatMul", False, []), # 0
|
||||
("Mul", True, [(0, 0, 0)]), # 1
|
||||
("Mul", True, [(0, 0, 1)]), # 2
|
||||
("Cast", True, [(1, 0, 0)]), # 3
|
||||
("Cast", True, [(2, 0, 0)]), # 4
|
||||
("Transpose", True, [(3, 0, 0)]), # 5
|
||||
("Transpose", True, [(4, 0, 0)]), # 6
|
||||
("Add", False, [(0, 0, 0)]), # 7
|
||||
("Cast", True, [(7, 0, 1)]), # 8
|
||||
("Slice", True, [(8, 0, 0)]), # 9
|
||||
("Slice", True, [(9, 0, 0)]), # 10
|
||||
("Unsqueeze", True, [(9, 0, 2)]), # 11
|
||||
("Gather", True, [(11, 0, 0)]), # 12
|
||||
("Shape", True, [(12, 0, 0)]), # 13
|
||||
("Softmax", False, [(7, 0, 0)]), # 14
|
||||
("Cast", False, [(14, 0, 0)]), # 15
|
||||
("MatMul", False, [(15, 0, 0)]), # 16
|
||||
("Transpose", True, [(16, 0, 1)]), # 17
|
||||
("Transpose", False, [(16, 0, 0)]), # 18
|
||||
("FusedMatMul", False, [(17, 0, 1)]), # 19
|
||||
("Cast", False, [(19, 0, 0)]), # 20
|
||||
("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21
|
||||
("Identity", False, [(21, 0, 0)]), # 22
|
||||
("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23
|
||||
("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24
|
||||
("Mul", False, [(23, 0, 0)]), # 25
|
||||
("Mul", False, [(24, 0, 0)]), # 26
|
||||
("Identity", False, [(25, 0, 0)]), # 27
|
||||
("Identity", False, [(26, 0, 0)]), # 28
|
||||
("Cast", False, [(27, 0, 0)]), # 29
|
||||
("Cast", False, [(28, 0, 0)]), # 30
|
||||
("Transpose", False, [(29, 0, 0)]), # 31
|
||||
("Transpose", False, [(30, 0, 0)]), # 32
|
||||
("FusedMatMul", False, [(15, 0, 0)]), # 33
|
||||
("Transpose", True, [(33, 0, 1)]), # 34
|
||||
("Transpose", False, [(33, 0, 0)]), # 35
|
||||
]
|
||||
|
||||
|
||||
def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
|
||||
# Check forward only as the backward is expected to be consistent if it's built correctly.
|
||||
scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
|
||||
scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
|
||||
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
|
||||
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
|
||||
if not (
|
||||
check_attribute_value(nodes[3], "to", 1)
|
||||
and check_attribute_value(nodes[4], "to", 1)
|
||||
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
|
||||
and check_attribute_value(nodes[15], "to", 10)
|
||||
and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3])
|
||||
and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3])
|
||||
and scale_value_1 == scale_value_2
|
||||
):
|
||||
return [], [], []
|
||||
|
||||
nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
|
||||
idx,
|
||||
nodes[5].input[0],
|
||||
nodes[6].input[0],
|
||||
nodes[17].input[0],
|
||||
nodes[18].output[0],
|
||||
nodes[34].input[0],
|
||||
nodes[31].output[0],
|
||||
nodes[32].output[0],
|
||||
nodes[35].output[0],
|
||||
"",
|
||||
False,
|
||||
scale_value_1,
|
||||
0.0,
|
||||
True,
|
||||
)
|
||||
return nodes, nodes_to_add, new_value_infos
|
||||
|
||||
|
||||
_PATTERNS = [
|
||||
(_PATTERN_0, _optimize_for_pattern_0),
|
||||
(_PATTERN_1, _optimize_for_pattern_1),
|
||||
(_PATTERN_2, _optimize_for_pattern_2),
|
||||
(_PATTERN_3, _optimize_for_pattern_3),
|
||||
]
|
||||
|
||||
|
||||
@register_graph_optimizer(devices="cuda")
|
||||
def optimize_graph_for_aten_efficient_attention(graph: GraphProto):
|
||||
nodes_to_remove = []
|
||||
nodes_to_add = []
|
||||
new_value_infos = []
|
||||
matcher = GraphMatcher(graph)
|
||||
idx = 0
|
||||
for pattern_tuple in _PATTERNS:
|
||||
for nodes in matcher.match_pattern(pattern_tuple[0]):
|
||||
remove_nodes, add_nodes, add_value_infos = pattern_tuple[1](matcher, idx, nodes)
|
||||
if len(add_nodes) > 0:
|
||||
nodes_to_remove.extend(remove_nodes)
|
||||
nodes_to_add.extend(add_nodes)
|
||||
new_value_infos.extend(add_value_infos)
|
||||
idx += 1
|
||||
update_graph(graph, nodes_to_remove, nodes_to_add, new_value_infos)
|
|
@ -0,0 +1,178 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper
|
||||
|
||||
|
||||
def _get_attribute(node: NodeProto, attr_name: str, default_value: Any = None) -> Any:
|
||||
"""Get attribute value from node by attribute key."""
|
||||
found = [attr for attr in node.attribute if attr.name == attr_name]
|
||||
if found:
|
||||
return helper.get_attribute_value(found[0])
|
||||
return default_value
|
||||
|
||||
|
||||
def _to_numpy_array(node: Any) -> np.ndarray:
|
||||
"""Convert Constant node or TensorProto to Python value."""
|
||||
tensor = node
|
||||
if isinstance(node, NodeProto):
|
||||
tensor = _get_attribute(node, "value")
|
||||
assert isinstance(tensor, TensorProto)
|
||||
return numpy_helper.to_array(tensor).tolist()
|
||||
|
||||
|
||||
class GraphMatcher:
|
||||
"""Sub-graph matcher with given pattern.
|
||||
|
||||
GraphMatcher takes an ONNX graph to initialize. It tries to match sub-graphs to a given pattern and yield
|
||||
matched sub-graphs (a list of matched nodes for each sub-graph) one by one.
|
||||
|
||||
Pattern is described by a list. Each entry of the list is a Tuple:
|
||||
|
||||
Tuple[str, bool, List[Tuple[int, int, int]]], e.g., ("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)])
|
||||
|
||||
* First string is the Op type, e.g., "FusedMatMul".
|
||||
* Second bool indicates it's producer node or consumer node for source node.
|
||||
* There is a list to describe the edge infos of this node to other nodes, each edge is a tuple with 3 integers,
|
||||
first integer is the index of the target node in the list, second integer is the output index of the edge,
|
||||
and thrid integer is the input index of the edge.
|
||||
|
||||
For each entry, GraphMatcher used the first edge to lookup target node, and try to use make sure the sug-graph also
|
||||
matches rest edge infos.
|
||||
|
||||
Note that when lookup target node, it will only take the first matched node as target node. For example, if a source
|
||||
node has multiple "MatMul" consumers nodes comsuming same output, only the first "MatMul" node will be returned.
|
||||
You need to avoid using such confusing edge info as the first edge info for node lookup. Try to use other edge to
|
||||
avoid such confusion if possible.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: GraphProto):
|
||||
self._graph: GraphProto = graph
|
||||
self._op_type_to_nodes: Dict[str, List[NodeProto]] = {}
|
||||
self._consumer_count: Dict[str, int] = {}
|
||||
for node in graph.node:
|
||||
if node.op_type not in self._op_type_to_nodes:
|
||||
self._op_type_to_nodes[node.op_type] = []
|
||||
self._op_type_to_nodes[node.op_type].append(node)
|
||||
for input in node.input:
|
||||
self._consumer_count[input] = self._consumer_count.get(input, 0) + 1
|
||||
|
||||
def _get_producer(self, arg: str, op_type: str, output_idx: int):
|
||||
for node in self._op_type_to_nodes.get(op_type, []):
|
||||
if (output_idx >= 0 and len(node.output) > output_idx and node.output[output_idx] == arg) or (
|
||||
output_idx == -1 and arg in node.output
|
||||
):
|
||||
return node
|
||||
return None
|
||||
|
||||
def _get_consumer(self, arg: str, op_type: str, input_idx: int):
|
||||
for node in self._op_type_to_nodes.get(op_type, []):
|
||||
if (input_idx >= 0 and len(node.input) > input_idx and node.input[input_idx] == arg) or (
|
||||
input_idx == -1 and arg in node.input
|
||||
):
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_consumer_count(self, arg: str):
|
||||
return self._consumer_count.get(arg, 0)
|
||||
|
||||
def get_constant_value(self, arg: str):
|
||||
node_or_initializer = None
|
||||
if "Constant" in self._op_type_to_nodes:
|
||||
for node in self._op_type_to_nodes["Constant"]:
|
||||
if arg in node.output:
|
||||
node_or_initializer = node
|
||||
break
|
||||
if node_or_initializer is None:
|
||||
for initializer in self._graph.initializer:
|
||||
if arg == initializer.name:
|
||||
node_or_initializer = initializer
|
||||
break
|
||||
if node_or_initializer is None:
|
||||
return None
|
||||
return _to_numpy_array(node_or_initializer)
|
||||
|
||||
def get_type_and_shape(self, arg: str):
|
||||
value_infos = [
|
||||
value_info
|
||||
for value_info in itertools.chain(self._graph.input, self._graph.value_info)
|
||||
if value_info.name == arg
|
||||
]
|
||||
if len(value_infos) > 0 and value_infos[0].type.tensor_type.HasField("shape"):
|
||||
shape = []
|
||||
for dim in value_infos[0].type.tensor_type.shape.dim:
|
||||
if dim.dim_param:
|
||||
shape.append(dim.dim_param)
|
||||
else:
|
||||
shape.append(dim.dim_value)
|
||||
return value_infos[0].type.tensor_type.elem_type, shape
|
||||
initializers = [initializer for initializer in self._graph.initializer if initializer.name == arg]
|
||||
if len(initializers) > 0:
|
||||
return initializers[0].data_type, initializers[0].dims
|
||||
return None, None
|
||||
|
||||
def _match_pattern(self, node: NodeProto, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]):
|
||||
nodes = [node]
|
||||
for i in range(1, len(pattern)):
|
||||
next_op_type = pattern[i][0]
|
||||
is_producer = pattern[i][1]
|
||||
node_idx, output_idx, input_idx = pattern[i][2][0]
|
||||
next_node = (
|
||||
self._get_producer(nodes[node_idx].input[input_idx], next_op_type, output_idx)
|
||||
if is_producer
|
||||
else self._get_consumer(nodes[node_idx].output[output_idx], next_op_type, input_idx)
|
||||
)
|
||||
if next_node is None:
|
||||
return []
|
||||
for j in range(1, len(pattern[i][2])):
|
||||
node_idx, output_idx, input_idx = pattern[i][2][j]
|
||||
assert output_idx >= 0 and input_idx >= 0
|
||||
if (not is_producer and nodes[node_idx].output[output_idx] != next_node.input[input_idx]) or (
|
||||
is_producer and next_node.output[output_idx] != nodes[node_idx].input[input_idx]
|
||||
):
|
||||
return []
|
||||
nodes.append(next_node)
|
||||
return nodes
|
||||
|
||||
def match_pattern(self, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]):
|
||||
for node in self._op_type_to_nodes.get(pattern[0][0], []):
|
||||
result = self._match_pattern(node, pattern)
|
||||
if len(result) == len(pattern):
|
||||
yield result
|
||||
|
||||
|
||||
def check_attribute_value(node: NodeProto, attr_name: str, expected_value: Any):
|
||||
"""Check if the attribute of given node has expected value."""
|
||||
value = _get_attribute(node, attr_name)
|
||||
return value == expected_value
|
||||
|
||||
|
||||
def make_constant_node(name: str, dtype: TensorProto.DataType, dims: Sequence[int], vals: Any):
|
||||
"""Create a constant node with given constant tensor (data type, shape, and data)."""
|
||||
return helper.make_node(
|
||||
"Constant",
|
||||
inputs=[],
|
||||
outputs=[name],
|
||||
value=helper.make_tensor(name=name, data_type=dtype, dims=dims, vals=vals),
|
||||
)
|
||||
|
||||
|
||||
def update_graph(
|
||||
graph: GraphProto,
|
||||
nodes_to_remove: List[NodeProto],
|
||||
nodes_to_add: List[NodeProto],
|
||||
new_value_infos: List[TensorProto] = [], # noqa: B006
|
||||
):
|
||||
"""Update an ONNX graph by removing some nodes, and adding some new nodes and value infos."""
|
||||
nodes = [node for node in graph.node if node not in nodes_to_remove]
|
||||
nodes.extend(nodes_to_add)
|
||||
graph.ClearField("node")
|
||||
graph.node.extend(nodes)
|
||||
if len(new_value_infos) > 0:
|
||||
graph.value_info.extend(new_value_infos)
|
|
@ -1,47 +0,0 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from onnx.onnx_ml_pb2 import GraphProto
|
||||
|
||||
|
||||
class GraphTransformerRegistry:
|
||||
_TRANSFORMER_FUNCS = {} # noqa: RUF012
|
||||
|
||||
@classmethod
|
||||
def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]):
|
||||
modules = []
|
||||
if target_modules == "all":
|
||||
modules.append("all")
|
||||
else:
|
||||
modules = target_modules.split("|")
|
||||
for module in modules:
|
||||
if module in cls._TRANSFORMER_FUNCS:
|
||||
cls._TRANSFORMER_FUNCS[module].append((fn, devices, priority))
|
||||
else:
|
||||
cls._TRANSFORMER_FUNCS[module] = [(fn, devices, priority)]
|
||||
|
||||
@classmethod
|
||||
def transform_all(cls, module_name: str, device: str, graph: GraphProto):
|
||||
transformers_to_apply = []
|
||||
if "all" in cls._TRANSFORMER_FUNCS:
|
||||
transformers_to_apply.extend(cls._TRANSFORMER_FUNCS["all"])
|
||||
if module_name in cls._TRANSFORMER_FUNCS:
|
||||
transformers_to_apply.extend(cls._TRANSFORMER_FUNCS[module_name])
|
||||
transformers_to_apply = [x for x in transformers_to_apply if x[1] == "all" or device in x[1]]
|
||||
transformers_to_apply.sort(key=lambda x: x[2], reverse=True)
|
||||
for fn, _, _ in transformers_to_apply:
|
||||
fn(graph)
|
||||
|
||||
|
||||
# target_modules can be multiple module names separated by "|", or "all" means apply to all modules.
|
||||
# devices can be multiple device types separated by "|" or "all" means apply to all devices.
|
||||
def register_graph_transformer(target_modules: str = "all", devices: str = "all", priority: int = 0):
|
||||
def graph_transformer_wrapper(fn):
|
||||
GraphTransformerRegistry.register(target_modules, devices, priority, fn)
|
||||
return fn
|
||||
|
||||
return graph_transformer_wrapper
|
|
@ -17,8 +17,8 @@ InlinedHashSet<size_t> TritonOp::GetBoolOutputs(size_t output_size) const {
|
|||
InlinedHashSet<size_t> bool_outputs;
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
ORT_ENFORCE(i < Node().OutputDefs().size(), "Output index out of range.");
|
||||
if (Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() ==
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) {
|
||||
if (Node().OutputDefs()[i]->Exists() && Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() ==
|
||||
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) {
|
||||
bool_outputs.insert(i);
|
||||
}
|
||||
}
|
||||
|
@ -37,13 +37,15 @@ Status TritonOp::Compute(OpKernelContext* context) const {
|
|||
InlinedHashSet<size_t> bool_outputs = GetBoolOutputs(output_size);
|
||||
auto& executor = training::framework::triton::TritonOpExecutor::Instance();
|
||||
if (func_name_ != "") {
|
||||
executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs);
|
||||
executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs, kwargs_);
|
||||
} else {
|
||||
executor.ExecuteByOnnx(onnx_key_, onnx_string_, inputs, outputs, bool_outputs);
|
||||
}
|
||||
ORT_ENFORCE(output_size == outputs.size());
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast<int>(i), outputs[i]));
|
||||
if (Node().OutputDefs()[i]->Exists()) {
|
||||
ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast<int>(i), outputs[i]));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/inlined_containers.h"
|
||||
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include "core/framework/op_kernel.h"
|
||||
#endif
|
||||
|
@ -18,6 +20,19 @@ class TritonOp final : public OpKernel {
|
|||
ORT_THROW_IF_ERROR(info.GetAttr("func_name", &func_name_));
|
||||
ORT_THROW_IF_ERROR(info.GetAttr("onnx_key", &onnx_key_));
|
||||
ORT_THROW_IF_ERROR(info.GetAttr("onnx_string", &onnx_string_));
|
||||
for (const auto& attr : info.node().GetAttributes()) {
|
||||
if (attr.first.rfind("_", 0) == 0 || attr.first == "func_name" || attr.first == "onnx_key" ||
|
||||
attr.first == "onnx_string") {
|
||||
continue;
|
||||
}
|
||||
// Support int64 and float only for now, skip other types.
|
||||
if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) {
|
||||
kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}});
|
||||
} else if (attr.second.type() ==
|
||||
ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) {
|
||||
kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
@ -28,6 +43,7 @@ class TritonOp final : public OpKernel {
|
|||
std::string func_name_;
|
||||
int64_t onnx_key_;
|
||||
std::string onnx_string_;
|
||||
InlinedHashMap<std::string, std::pair<std::string, int>> kwargs_;
|
||||
};
|
||||
|
||||
bool IsTritonOpExecutorInitialized();
|
||||
|
|
|
@ -92,3 +92,4 @@ unfixable = [
|
|||
"tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix
|
||||
"onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix
|
||||
"onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix
|
||||
"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo.
|
||||
|
|
1
setup.py
1
setup.py
|
@ -466,6 +466,7 @@ if enable_training or enable_training_apis:
|
|||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
|
||||
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
|
||||
"onnxruntime.training.ortmodule.graph_optimizers",
|
||||
"onnxruntime.training.ort_triton",
|
||||
"onnxruntime.training.ort_triton.kernel",
|
||||
"onnxruntime.training.utils",
|
||||
|
|
Загрузка…
Ссылка в новой задаче