# Black code formatter configuration. Line length is project-wide (120)
# and must match [tool.isort] line_length below.
[tool.black]
line-length = 120
# NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead
extend-exclude = "cmake|onnxruntime/core/flatbuffers/"
target-version = ["py37", "py38", "py39", "py310", "py311"]

# isort import sorter; "black" profile keeps it compatible with the
# Black formatting above (same 120-column limit).
[tool.isort]
# NOTE: Do not extend the exclude list. Edit .lintrunner.toml instead
profile = "black"
line_length = 120
extend_skip_glob = [
    "cmake/*",
    "orttraining/*",
    "onnxruntime/core/flatbuffers/*",
]

# Docstring style checker: enforce Google-style docstrings.
[tool.pydocstyle]
convention = "google"

# Short identifiers pylint should not flag as non-descriptive
# (common math/loop variable names).
[tool.pylint.BASIC]
good-names = [
    "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n",
    "p", "q", "r", "s", "t", "u", "v", "w", "ex", "Run", "_", "x", "y", "z",
]

# Pylint checks disabled project-wide. Formatting-related checks are
# off because Black/isort own formatting; import/member resolution
# checks are off because pylint cannot see compiled/generated modules.
[tool.pylint.messages_control]
disable = [
    "format",
    "line-too-long",
    "import-error",
    "no-name-in-module",
    "no-member",
    "too-many-arguments",
    "too-many-locals",
    "too-few-public-methods",
    "missing-docstring",
    "fixme",
]

# Pyright static type checker. Missing-import reports are suppressed
# because many imports resolve only in the full build environment.
[tool.pyright]
exclude = ["onnxruntime/core/flatbuffers/*"]
reportMissingImports = false

# Ruff linter. Targets the oldest supported Python so rules do not
# suggest syntax newer than py38.
[tool.ruff]
# NOTE: Do not create an exclude list. Edit .lintrunner.toml instead
target-version = "py38"

# Ruff rule selection. `select` opts into rule families by prefix;
# `ignore` carves out specific rules that conflict with project style
# or are pending fixes (each entry explains why).
[tool.ruff.lint]
select = [
    "B", # flake8-bugbear
    "E", # pycodestyle
    "F", # Pyflakes
    "FURB", # refurb
    "G", # flake8-logging-format
    "ISC", # flake8-implicit-str-concat
    "N", # pep8-naming
    "NPY", # numpy
    "PERF", # Perflint
    "PIE", # flake8-pie
    "PLC", # pylint conventions
    "PLE", # pylint errors
    "PLW", # pylint warnings
    "PYI", # flake8-pyi
    "RUF", # Ruff-specific rules
    "SIM", # flake8-simplify
    "SLOT", # flake8-slots
    "T10", # flake8-debugger
    "UP", # pyupgrade
    "W", # pycodestyle
    "YTT", # flake8-2020
]
# NOTE: Refrain from growing the ignore list unless for exceptional cases.
# Always include a comment to explain why.
ignore = [
    "B028", # FIXME: Add stacklevel to warnings
    "E501", # Line length controlled by black
    "G004", # FIXME: Enable when the rule can be autofixed
    "N803", # Argument casing
    "N812", # Allow import torch.nn.functional as F
    "N813", # Allow importing camelcase names in lowercase
    "N999", # Module names
    "NPY002", # np.random.Generator may not always fit our use cases
    "PERF203", # "try-except-in-loop" only affects Python <3.11, and the improvement is minor; can have false positives
    "PERF401", # List comprehensions are not always readable
    "PYI041", # May create confusion
    "PYI024", # May create confusion
    "SIM102", # We don't prefer always combining if branches
    "SIM103", # Do not collapse if-else
    "SIM108", # We don't encourage ternary operators
    "SIM114", # Don't combine if branches for debuggability
    "SIM116", # Don't use dict lookup to replace if-else
]
ignore-init-module-imports = true
# Rules ruff may report but must never autofix (fixes would change behavior).
unfixable = [
    "F401", # Unused imports
    "SIM112", # Use upper case for env vars
]

# Per-file rule exemptions for files that cannot reasonably comply
# (vendored/third-party-derived code, triton kernels, test fixtures).
[tool.ruff.lint.per-file-ignores]
# NOTE: Refrain from growing the ignore list unless for exceptional cases.
# Prefer inline ignores with `noqa: xxx`.
# Eventually this list should become empty.
"orttraining/orttraining/test/**" = ["N802"] # Function casing
"tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix
"onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_triton.py" = ["N806"] # use of Q, K and V in triton script
"onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_triton.py" = ["N806"] # use of Q, K and V in triton script
"onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix
"onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix
"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo.