Adds ATen fallback for scaled_dot_product_attention (#21107)
### Description <!-- Describe your changes. --> Introduces an ATen fallback for `torch.nn.functional.scaled_dot_product_attention`. This operator was introduced in torch 2.0 and, since then, has had many updates including the implementation of memory efficient attention for V100 machines. The current torchscript exporter exports a subgraph for attention which does not provide the same memory savings that PyTorch's memory efficient attention kernel provides. Allowing fallback to PyTorch ATen op for attention helps mitigate memory spike issues for models leveraging memory efficient attention. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Memory issues arose when integrating ONNX Runtime Training with AML Stable Diffusion. --------- Co-authored-by: root <prathikrao@microsoft.com>
This commit is contained in:
Родитель
5b9369e93c
Коммит
11ad299451
|
@ -304,6 +304,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
|
|||
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
|
||||
```
|
||||
|
||||
#### ORTMODULE_ATEN_SDPA_FALLBACK
|
||||
|
||||
- **Feature Area**: *ORTMODULE/Optimizations*
|
||||
- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's [_scaled_dot_product_efficient_attention](https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778) ATen kernel for execution when calling torch.nn.functional.scaled_dot_product_attention. NOTE: only use this feature if user model leverages memory efficient attention WITHOUT masking (ie. attn_mask=None). Utilize GPU profiling looks like NVIDIA Nsight Systems to identify if user model leverages memory efficient attention.
|
||||
|
||||
```bash
|
||||
export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
|
||||
unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
|
||||
```
|
||||
|
||||
### 2.2 Memory Optimization
|
||||
|
||||
Q: *Want to run a bigger batch size?*
|
||||
|
|
|
@ -1794,7 +1794,20 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
|
|||
}
|
||||
|
||||
std::vector<ArgDef> output_args;
|
||||
for (const auto& output : node_def.outputs) {
|
||||
for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
|
||||
// If the input is not used in the forward computation, we don't need it for gradient computation
|
||||
// Required for ORTMODULE_ATEN_SDPA_FALLBACK
|
||||
if (static_cast<int>(output_index) >= GetSrcNodeInputSize()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!IsGradientRequiredForSrcNodeInput(static_cast<int>(output_index))) {
|
||||
output_args.emplace_back(ArgDef());
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& output = node_def.outputs[output_index];
|
||||
|
||||
if (output.find("GI(") == 0) {
|
||||
size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
|
||||
output_args.emplace_back(GI(index));
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
# 'is_tensor' is optional, if not present, the default is False.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
|
||||
|
@ -276,3 +277,39 @@ def upsample_nearest3d_gradient():
|
|||
@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
|
||||
def upsample_bicubic2d_gradient():
|
||||
return _upsample_gradient("upsample_bicubic2d_backward", 2)
|
||||
|
||||
|
||||
ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
|
||||
if ATEN_SDPA_FALLBACK:
|
||||
# based on the following internal PyTorch kernel for efficient attention:
|
||||
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
|
||||
@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
|
||||
def scaled_dot_product_attention_gradient():
|
||||
return [
|
||||
(
|
||||
"Constant",
|
||||
[],
|
||||
["grad_input_mask"],
|
||||
{"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
|
||||
),
|
||||
(
|
||||
("ATen", "org.pytorch.aten"),
|
||||
[
|
||||
"GO(0)",
|
||||
"I(0)",
|
||||
"I(1)",
|
||||
"I(2)",
|
||||
"I(3)",
|
||||
"O(0)",
|
||||
"O(1)",
|
||||
"O(2)",
|
||||
"O(3)",
|
||||
"I(5)",
|
||||
"grad_input_mask",
|
||||
"I(6)",
|
||||
"I(7)",
|
||||
],
|
||||
["GI(0)", "GI(1)", "GI(2)", ""],
|
||||
{"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
@ -969,3 +970,26 @@ def softmax(g, input, dim, dtype=None):
|
|||
softmax = g.op("Softmax", casted_input, axis_i=dim)
|
||||
|
||||
return softmax
|
||||
|
||||
|
||||
ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
|
||||
if ATEN_SDPA_FALLBACK:
|
||||
# based on the following internal PyTorch kernel for efficient attention:
|
||||
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
|
||||
@register_symbolic("scaled_dot_product_attention")
|
||||
def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
|
||||
dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
|
||||
compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))
|
||||
return g.op(
|
||||
"org.pytorch.aten::ATen",
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
compute_logsumexp,
|
||||
dropout_p_f,
|
||||
is_causal,
|
||||
scale,
|
||||
operator_s="_scaled_dot_product_efficient_attention",
|
||||
outputs=4,
|
||||
)[0]
|
||||
|
|
|
@ -6953,3 +6953,74 @@ def test_layerwise_recompute_pythonop_determinstic():
|
|||
else:
|
||||
if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ:
|
||||
del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
Version(torch.__version__) < Version("2.3.0"),
|
||||
reason="torch.nn.attention module was introduced in PyTorch 2.3.0",
|
||||
)
|
||||
def test_aten_attention():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
class _NeuralNetAttention(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, q, k, v, attn_mask=None):
|
||||
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
def gen_inputs(device, dtype):
|
||||
return [
|
||||
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
|
||||
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
|
||||
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
|
||||
]
|
||||
|
||||
def run_step(model, inputs, attn_mask=None):
|
||||
prediction = model(*inputs, attn_mask)
|
||||
prediction.sum().backward()
|
||||
return prediction
|
||||
|
||||
device = "cuda"
|
||||
|
||||
os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1" # TESTING WITHOUT ATTN_MASK
|
||||
|
||||
pt_model = _NeuralNetAttention().to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))
|
||||
|
||||
# reset manual seed to reset the generator
|
||||
torch.manual_seed(2333)
|
||||
pt_input = gen_inputs(device=device, dtype=torch.float32)
|
||||
ort_input = copy.deepcopy(pt_input)
|
||||
pt_prediction = run_step(pt_model, pt_input)
|
||||
ort_prediction = run_step(ort_model, ort_input)
|
||||
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
|
||||
_test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
|
||||
_test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)
|
||||
|
||||
execution_mgr = ort_model._torch_module._execution_manager._training_manager
|
||||
from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name
|
||||
|
||||
path = os.path.join(
|
||||
execution_mgr._debug_options.save_onnx_models.path,
|
||||
_get_onnx_file_name(
|
||||
execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
|
||||
),
|
||||
)
|
||||
|
||||
onnx_model = onnx.load(path)
|
||||
onnx_nodes = onnx_model.graph.node
|
||||
|
||||
mem_eff_attn_nodes = 0
|
||||
for node in onnx_nodes:
|
||||
if "ATen" in node.name:
|
||||
for attr in node.attribute:
|
||||
if b"_scaled_dot_product_efficient_attention" in attr.s:
|
||||
mem_eff_attn_nodes += 1
|
||||
|
||||
assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
|
||||
|
||||
del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]
|
||||
|
|
Загрузка…
Ссылка в новой задаче