Mirror of https://github.com/microsoft/DeepSpeed.git
Skip Triton import for AMD (#5110)
When testing DeepSpeed inference on an `AMD Instinct MI250X/MI250` GPU, the `pytorch-triton-rocm` module breaks the `torch.cuda` device API. To address this, the `triton` import is skipped when the GPU is detected as AMD. This change allows DeepSpeed to run on an AMD GPU without kernel injection in the DeepSpeedExamples [text-generation example](https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/huggingface/text-generation) using the following command:

```bash
deepspeed --num_gpus 1 inference-test.py --model facebook/opt-125m
```

TODO: Root-cause the interaction between `pytorch-triton-rocm` and DeepSpeed to understand why it breaks the `torch.cuda` device API.
This commit is contained in:
Parent: 0a10bd427e
Commit: d04a8386d1
```diff
@@ -12,10 +12,14 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from packaging import version as pkg_version
 
-try:
-    import triton  # noqa: F401 # type: ignore
-    HAS_TRITON = True
-except ImportError:
+# Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed
+if not (hasattr(torch.version, 'hip') and torch.version.hip is not None):
+    try:
+        import triton  # noqa: F401 # type: ignore
+        HAS_TRITON = True
+    except ImportError:
+        HAS_TRITON = False
+else:
     HAS_TRITON = False
 
 from . import ops
```
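The guard in the diff relies on the fact that ROCm builds of PyTorch expose a non-`None` `torch.version.hip`, while CUDA and CPU builds leave it `None` (or absent). A minimal sketch of that detection logic, using `SimpleNamespace` stand-ins for `torch.version` so it runs without PyTorch installed (the version strings below are illustrative, not taken from the commit):

```python
from types import SimpleNamespace

def is_rocm_build(torch_version) -> bool:
    """Mirror the patch's check: True when the torch.version module reports a HIP build."""
    return hasattr(torch_version, 'hip') and torch_version.hip is not None

# Hypothetical stand-ins for torch.version under different builds:
cuda_build = SimpleNamespace(hip=None, cuda="12.1")   # CUDA build: hip is None
rocm_build = SimpleNamespace(hip="5.7.31921", cuda=None)  # ROCm build: hip is set

assert not is_rocm_build(cuda_build)
assert is_rocm_build(rocm_build)
```

On a ROCm build the guard short-circuits before `import triton` runs, so `HAS_TRITON` is set to `False` without ever loading `pytorch-triton-rocm`.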