Adds ATen fallback for scaled_dot_product_attention (#21107)

### Description

Introduces an ATen fallback for
`torch.nn.functional.scaled_dot_product_attention`. This operator was
introduced in PyTorch 2.0 and has since received many updates, including
an implementation of memory-efficient attention for V100 machines. The
current TorchScript exporter exports a subgraph for attention that does
not provide the same memory savings as PyTorch's memory-efficient
attention kernel. Allowing attention to fall back to the PyTorch ATen op
helps mitigate memory spikes for models that leverage memory-efficient
attention.
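
As a rough usage sketch (the module below is illustrative and not part of this change), a model that calls `scaled_dot_product_attention` without a mask can be wrapped with ORTModule and run with the fallback enabled:

```python
# Run with the fallback enabled, e.g.:  export ORTMODULE_ATEN_SDPA_FALLBACK=1
import torch
import torch.nn.functional as F
from onnxruntime.training.ortmodule import ORTModule


class _Attention(torch.nn.Module):  # illustrative module, not part of this change
    def forward(self, q, k, v):
        # attn_mask stays None so PyTorch's memory-efficient kernel applies
        return F.scaled_dot_product_attention(q, k, v)


model = ORTModule(_Attention().to("cuda"))
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True) for _ in range(3))
model(q, k, v).sum().backward()  # attention is exported as a single ATen op instead of a subgraph
```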

### Motivation and Context

Memory issues arose when integrating ONNX Runtime Training with AML
Stable Diffusion.

---------

Co-authored-by: root <prathikrao@microsoft.com>
Prathik Rao authored 2024-07-22 16:37:04 -07:00, committed by GitHub
Parent: 5b9369e93c
Commit: 11ad299451
5 changed files: 156 additions and 1 deletion


@@ -304,6 +304,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```
#### ORTMODULE_ATEN_SDPA_FALLBACK
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used to enable a pre-export fallback of attention to PyTorch's [_scaled_dot_product_efficient_attention](https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778) ATen kernel when `torch.nn.functional.scaled_dot_product_attention` is called. NOTE: only use this feature if the user model leverages memory efficient attention WITHOUT masking (i.e. `attn_mask=None`). Use a GPU profiling tool such as NVIDIA Nsight Systems to identify whether the user model leverages memory efficient attention (a minimal sketch of one such check follows the example below).
```bash
export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
```
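
As a rough illustration of such a check (this snippet is not part of ORTModule, and `torch.nn.attention.sdpa_kernel` requires PyTorch 2.3+), restricting eager-mode SDPA to the memory-efficient backend will raise an error if the model's inputs cannot use that kernel:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))

# Restrict SDPA to the memory-efficient backend; PyTorch raises a RuntimeError
# if these inputs (e.g. with an incompatible attn_mask) cannot use it.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)  # attn_mask=None, as this fallback requires
```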
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*


@@ -1794,7 +1794,20 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
  }
  std::vector<ArgDef> output_args;
-  for (const auto& output : node_def.outputs) {
+  for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
+    // If the input is not used in the forward computation, we don't need it for gradient computation
+    // Required for ORTMODULE_ATEN_SDPA_FALLBACK
+    if (static_cast<int>(output_index) >= GetSrcNodeInputSize()) {
+      continue;
+    }
+    if (!IsGradientRequiredForSrcNodeInput(static_cast<int>(output_index))) {
+      output_args.emplace_back(ArgDef());
+      continue;
+    }
+    const auto& output = node_def.outputs[output_index];
    if (output.find("GI(") == 0) {
      size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
      output_args.emplace_back(GI(index));


@@ -25,6 +25,7 @@
# 'is_tensor' is optional, if not present, the default is False.
import json
import os
from onnxruntime.capi import _pybind_state as C
@@ -276,3 +277,39 @@ def upsample_nearest3d_gradient():
@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
def upsample_bicubic2d_gradient():
    return _upsample_gradient("upsample_bicubic2d_backward", 2)


ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
if ATEN_SDPA_FALLBACK:
    # based on the following internal PyTorch kernel for efficient attention:
    # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
    @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
    def scaled_dot_product_attention_gradient():
        return [
            (
                "Constant",
                [],
                ["grad_input_mask"],
                {"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
            ),
            (
                ("ATen", "org.pytorch.aten"),
                [
                    "GO(0)",
                    "I(0)",
                    "I(1)",
                    "I(2)",
                    "I(3)",
                    "O(0)",
                    "O(1)",
                    "O(2)",
                    "O(3)",
                    "I(5)",
                    "grad_input_mask",
                    "I(6)",
                    "I(7)",
                ],
                ["GI(0)", "GI(1)", "GI(2)", ""],
                {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
            ),
        ]
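
For readability, this is how the placeholders above are expected to line up with `_scaled_dot_product_efficient_attention_backward`, assuming the forward ATen node is produced by the symbolic registration in this change (inputs: query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, scale; outputs: output, log_sumexp, philox_seed, philox_offset). This is an explanatory sketch, not code added by the commit:

```python
# Explanatory sketch of the gradient definition above (not shipped code).
# I(i)/O(i) index the forward ATen node's inputs/outputs, GO(i) is the incoming
# gradient of forward output i, and GI(i) is the produced gradient of forward input i.
backward_inputs = {
    "GO(0)": "grad_out",        # gradient of the attention output
    "I(0)": "query",
    "I(1)": "key",
    "I(2)": "value",
    "I(3)": "attn_bias",        # attn_mask in the symbolic registration
    "O(0)": "out",
    "O(1)": "logsumexp",
    "O(2)": "philox_seed",
    "O(3)": "philox_offset",
    "I(5)": "dropout_p",
    "grad_input_mask": [1, 1, 1, 0],  # request grads for query/key/value, not attn_bias
    "I(6)": "is_causal",
    "I(7)": "scale",
}
backward_outputs = ["GI(0)", "GI(1)", "GI(2)", ""]  # grads of query, key, value; attn_bias grad unused
```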


@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
from typing import Callable
import torch
@@ -969,3 +970,26 @@ def softmax(g, input, dim, dtype=None):
    softmax = g.op("Softmax", casted_input, axis_i=dim)
    return softmax


ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
if ATEN_SDPA_FALLBACK:
    # based on the following internal PyTorch kernel for efficient attention:
    # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
    @register_symbolic("scaled_dot_product_attention")
    def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))
        return g.op(
            "org.pytorch.aten::ATen",
            query,
            key,
            value,
            attn_mask,
            compute_logsumexp,
            dropout_p_f,
            is_causal,
            scale,
            operator_s="_scaled_dot_product_efficient_attention",
            outputs=4,
        )[0]


@@ -6953,3 +6953,74 @@ def test_layerwise_recompute_pythonop_determinstic():
    else:
        if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ:
            del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]


@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.3.0"),
    reason="torch.nn.attention module was introduced in PyTorch 2.3.0",
)
def test_aten_attention():
    from torch.nn.attention import SDPBackend, sdpa_kernel

    class _NeuralNetAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, q, k, v, attn_mask=None):
            with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
                return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)

    def gen_inputs(device, dtype):
        return [
            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
        ]

    def run_step(model, inputs, attn_mask=None):
        prediction = model(*inputs, attn_mask)
        prediction.sum().backward()
        return prediction

    device = "cuda"

    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"  # TESTING WITHOUT ATTN_MASK

    pt_model = _NeuralNetAttention().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))

    # reset manual seed to reset the generator
    torch.manual_seed(2333)

    pt_input = gen_inputs(device=device, dtype=torch.float32)
    ort_input = copy.deepcopy(pt_input)
    pt_prediction = run_step(pt_model, pt_input)
    ort_prediction = run_step(ort_model, ort_input)

    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
    _test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
    _test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
    _test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)

    execution_mgr = ort_model._torch_module._execution_manager._training_manager
    from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name

    path = os.path.join(
        execution_mgr._debug_options.save_onnx_models.path,
        _get_onnx_file_name(
            execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
        ),
    )

    onnx_model = onnx.load(path)
    onnx_nodes = onnx_model.graph.node

    mem_eff_attn_nodes = 0
    for node in onnx_nodes:
        if "ATen" in node.name:
            for attr in node.attribute:
                if b"_scaled_dot_product_efficient_attention" in attr.s:
                    mem_eff_attn_nodes += 1

    assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"

    del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]