Mirror of https://github.com/microsoft/DeepSpeed.git
[NPU] Add NPU support for unit test (#4569)
Unit tests would fail or be skipped when device=npu, and we definitely want to exercise all these features through the official unit tests. This commit adds NPU support to the unit tests. P.S. See what we have already done in #4567.

**What this commit does**

1. Add an npu branch to the existing device logic:
   - feat: add npu support for `skip_on_arch` in `tests/unit/util.py`
   - feat: add npu support for `skip_on_cuda` in `tests/unit/util.py`
   - feat: add npu support to `tests/unit/common.py`
2. Call the accelerator's `set_device` before `deepspeed.init_distributed` in `tests/unit/common.py`. Setting the device first is friendlier and easier for devices like npu, and setting the device parameter before initialization is the more natural order.
3. Fix calling `get_accelerator().random().fork_rng` on non-cuda devices. `train_cifar()` in `tests/unit/alexnet_model.py` calls `get_accelerator().random().fork_rng` without passing `device_type` explicitly. Unfortunately, `torch.random.fork_rng()` defaults to `device_type="cuda"`, so non-cuda devices fail to run. The fix is to pass `device_type=get_accelerator().device_name()` explicitly, so both cuda and non-cuda devices behave correctly. (Points 2 and 3 are sketched in the code note below.)

---------

Co-authored-by: ryan <ruanzhixiang1@huawei.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Parent: 0a6095faa0
Commit: 4b7cae7bea
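For orientation, here is a minimal sketch of the two runtime changes described in points 2 and 3 of the message above. It is illustrative rather than a standalone script: `local_rank` is a placeholder for the launcher-provided rank, `init_distributed` expects a distributed launcher environment, and the `device_type` keyword assumes a PyTorch version whose `torch.random.fork_rng` accepts it.

```python
import deepspeed
import torch
from deepspeed.accelerator import get_accelerator

local_rank = 0  # placeholder: normally provided by the distributed launcher

# Point 2: pin the device on the accelerator *before* init_distributed;
# this is friendlier to non-cuda devices such as npu.
if get_accelerator().is_available():
    get_accelerator().set_device(local_rank)
deepspeed.init_distributed()  # backend defaults to the accelerator's own choice

# Point 3: torch.random.fork_rng() defaults to device_type="cuda", so pass
# the current accelerator's device type to keep non-cuda devices working.
with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()],
                                         device_type=get_accelerator().device_name()):
    torch.manual_seed(123)  # RNG state changes are rolled back on exit
```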
accelerator/__init__.py
@@ -4,4 +4,4 @@
 # DeepSpeed Team

 from .abstract_accelerator import DeepSpeedAccelerator
-from .real_accelerator import get_accelerator, set_accelerator
+from .real_accelerator import get_accelerator, set_accelerator, is_current_accelerator_supported
accelerator/real_accelerator.py
@@ -20,6 +20,8 @@ try:
 except ImportError as e:
     dsa2 = None

+SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']
+
 ds_accelerator = None

@@ -34,7 +36,7 @@ def _validate_accelerator(accel_obj):
     # accelerator.abstractor_accelerator
     # or deepspeed.accelerator.abstract_accelerator, consider accel_obj
     # is a conforming object
-    if not ((dsa1 != None and isinstance(accel_obj, dsa1)) or (dsa2 != None and isinstance(accel_obj, dsa2))):
+    if not ((dsa1 is not None and isinstance(accel_obj, dsa1)) or (dsa2 is not None and isinstance(accel_obj, dsa2))):
         raise AssertionError(f"{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator")

     # TODO: turn off is_available test since this breaks tests
@@ -42,6 +44,10 @@ def _validate_accelerator(accel_obj):
     # f'{accel_obj.__class__.__name__} accelerator fails is_available() test'


+def is_current_accelerator_supported():
+    return get_accelerator() in SUPPORTED_ACCELERATOR_LIST
+
+
 def get_accelerator():
     global ds_accelerator
     if ds_accelerator is not None:
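A small usage sketch for the new helper; `require_supported_accelerator` is a hypothetical name, not part of this commit:

```python
from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported


def require_supported_accelerator():
    # Hypothetical guard: fail fast when the current accelerator is outside
    # SUPPORTED_ACCELERATOR_LIST (the tests/unit/util.py changes below assert
    # the same condition instead of hard-coding 'xpu').
    assert is_current_accelerator_supported(), \
        f"{get_accelerator().device_name()} is not a supported accelerator"
```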
@@ -50,7 +56,6 @@ def get_accelerator():
     accelerator_name = None
     ds_set_method = None
     # 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
-    DS_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']
     if "DS_ACCELERATOR" in os.environ.keys():
         accelerator_name = os.environ["DS_ACCELERATOR"]
         if accelerator_name == "xpu":
@@ -79,15 +84,13 @@ def get_accelerator():
                 torch.mps.current_allocated_memory()
             except (RuntimeError, ImportError) as e:
                 raise ValueError(f"MPS_Accelerator requires torch.mps, which is not installed on this system.")
-        elif accelerator_name == "cuda":
-            pass
-        else:
-            raise ValueError(
-                f'DS_ACCELERATOR must be one of {DS_ACCELERATOR_LIST}. Value "{accelerator_name}" is not supported')
+        elif is_current_accelerator_supported():
+            raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
+                             f'Value "{accelerator_name}" is not supported')
         ds_set_method = "override"

     # 2. If no override, detect which accelerator to use automatically
-    if accelerator_name == None:
+    if accelerator_name is None:
         # We need a way to choose among different accelerator types.
         # Currently we detect which accelerator extension is installed
         # in the environment and use it if the installing answer is True.
@@ -105,21 +108,21 @@ def get_accelerator():
             accelerator_name = "xpu"
         except ImportError as e:
             pass
-        if accelerator_name == None:
+        if accelerator_name is None:
             try:
                 import intel_extension_for_pytorch  # noqa: F401,F811 # type: ignore

                 accelerator_name = "cpu"
             except ImportError as e:
                 pass
-        if accelerator_name == None:
+        if accelerator_name is None:
             try:
                 import torch_npu  # noqa: F401,F811 # type: ignore

                 accelerator_name = "npu"
             except ImportError as e:
                 pass
-        if accelerator_name == None:
+        if accelerator_name is None:
             try:
                 import torch.mps

@@ -128,7 +131,7 @@ def get_accelerator():
                 accelerator_name = "mps"
             except (RuntimeError, ImportError) as e:
                 pass
-        if accelerator_name == None:
+        if accelerator_name is None:
             accelerator_name = "cuda"

         ds_set_method = "auto detect"
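As a usage note for the override path above: `DS_ACCELERATOR` must be set before the first `get_accelerator()` call in the process, because the result is cached in `ds_accelerator`. A minimal sketch, assuming `torch_npu` is installed:

```python
import os

os.environ["DS_ACCELERATOR"] = "npu"  # must be set before get_accelerator() runs

from deepspeed.accelerator import get_accelerator

accelerator = get_accelerator()   # resolved through the "override" branch above
print(accelerator.device_name())  # -> "npu"
```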
tests/unit/alexnet_model.py
@@ -111,7 +111,8 @@ def cifar_trainset(fp16=False):


 def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
-    with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()]):
+    with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()],
+                                             device_type=get_accelerator().device_name()):
         ds_utils.set_random_seed(seed)

         # disable dropout
tests/unit/common.py
@@ -81,6 +81,9 @@ def set_accelerator_visible():
             match = re.search('Device Type.*GPU', line)
             if match:
                 num_accelerators += 1
+    elif get_accelerator().device_name() == 'npu':
+        npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
+        num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
     else:
         assert get_accelerator().device_name() == 'cpu'
         cpu_sockets = int(
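The npu branch above assumes the first line of `npu-smi info -l` reports the device count after a colon. A self-contained illustration of that parsing; the sample string is an assumed output shape, not captured tool output:

```python
# Assumed shape of `npu-smi info -l` output (illustrative only).
sample = b"Total Count                    : 8\n<per-device lines follow>"
num_accelerators = int(sample.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
assert num_accelerators == 8
```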
@@ -204,13 +207,13 @@ class DistributedExec(ABC):
         if get_accelerator().is_available():
             set_accelerator_visible()

+        if get_accelerator().is_available():
+            get_accelerator().set_device(local_rank)
+
         if self.init_distributed:
             deepspeed.init_distributed(dist_backend=self.backend)
             dist.barrier()

-        if get_accelerator().is_available():
-            get_accelerator().set_device(local_rank)
-
         try:
             self.run(**self._fixture_kwargs)
         except BaseException as e:
tests/unit/util.py
@@ -5,29 +5,29 @@

 import pytest
 import torch
-import deepspeed
+from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported
 from deepspeed.git_version_info import torch_info
 from packaging import version as pkg_version


 def skip_on_arch(min_arch=7):
-    if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
+    if get_accelerator().device_name() == 'cuda':
         if torch.cuda.get_device_capability()[0] < min_arch:  #ignore-cuda
             pytest.skip(f"needs higher compute capability than {min_arch}")
     else:
-        assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
+        assert is_current_accelerator_supported()
         return


 def skip_on_cuda(valid_cuda):
     split_version = lambda x: map(int, x.split('.')[:2])
-    if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
+    if get_accelerator().device_name() == 'cuda':
         CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])
         CUDA_VERSION = (CUDA_MAJOR * 10) + CUDA_MINOR
         if valid_cuda.count(CUDA_VERSION) == 0:
             pytest.skip(f"requires cuda versions {valid_cuda}")
     else:
-        assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
+        assert is_current_accelerator_supported()
         return
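For reference, a sketch of how tests call these helpers; the test names, bodies, and the `unit.util` import path are illustrative assumptions:

```python
from unit.util import skip_on_arch, skip_on_cuda  # import path assumed


def test_needs_ampere():
    skip_on_arch(min_arch=8)  # skips on cuda below compute capability 8.x
    ...


def test_cuda_11x_only():
    skip_on_cuda(valid_cuda=[110, 111, 116, 117, 118])  # CUDA_MAJOR * 10 + CUDA_MINOR
    ...
```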
@@ -43,8 +43,14 @@ def bf16_required_version_check(accelerator_check=True):
     else:
         accelerator_pass = True

-    if (TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)) and (CUDA_MAJOR >= 11) and (
-            NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and accelerator_pass:
+    torch_version_available = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+    cuda_version_available = CUDA_MAJOR >= 11
+    nccl_version_available = NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)
+    npu_available = get_accelerator().device_name() == 'npu'
+
+    if torch_version_available and cuda_version_available and nccl_version_available and accelerator_pass:
         return True
+    elif npu_available:
+        return True
     else:
         return False
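And the usual call pattern for `bf16_required_version_check`; the test name and body are illustrative:

```python
import pytest

from unit.util import bf16_required_version_check  # import path assumed


def test_bf16_feature():
    if not bf16_required_version_check():
        pytest.skip("bf16 needs torch >= 1.10, CUDA >= 11 and NCCL >= 2.10, or an npu device")
    ...
```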