зеркало из https://github.com/microsoft/DeepSpeed.git
fix init_device_mesh for torch 2.4 (#6614)
Starting with torch 2.4, in
[`init_device_mesh()`](de4c2a3b4e/torch/distributed/device_mesh.py#L915),
a device type with a GPU index, such as "cuda:0", is not allowed.
![image](https://github.com/user-attachments/assets/1ddb61bf-8a15-4e0a-9115-a3681d7f19ff)
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <mtanaka@microsoft.com>
This commit is contained in:
Родитель
e06bb518aa
Коммит
6e6563d3c8
|
@ -390,9 +390,14 @@ class TorchBackend(Backend):
|
|||
if not required_torch_version(min_version=2.2):
|
||||
raise RuntimeError(f"Current torch version does not have device mesh"
|
||||
f"api (torch.__version__: {torch.__version__})")
|
||||
return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
|
||||
mesh_shape,
|
||||
mesh_dim_names=mesh_dim_names)
|
||||
if not required_torch_version(max_version=2.4):
|
||||
return torch.distributed.device_mesh.init_device_mesh(get_accelerator().device_name(),
|
||||
mesh_shape,
|
||||
mesh_dim_names=mesh_dim_names)
|
||||
else:
|
||||
return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
|
||||
mesh_shape,
|
||||
mesh_dim_names=mesh_dim_names)
|
||||
|
||||
|
||||
# This will become a light-weight wrapper around torch.distributed functions
|
||||
|
|
Загрузка…
Ссылка в новой задаче