fix init_device_mesh for torch 2.4 (#6614)

Starting with torch 2.4,
[`init_device_mesh()`](de4c2a3b4e/torch/distributed/device_mesh.py (L915))
no longer accepts a device type that includes a GPU index, such as "cuda:0".
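
As a rough illustration of the restriction, and assuming the two forms differ only by the `:<index>` suffix, a device string can be reduced to its bare type as sketched below. `strip_device_index` is a hypothetical helper written for this note; it is not part of torch or DeepSpeed.

```python
def strip_device_index(device: str) -> str:
    """Return the device type without a trailing index.

    Hypothetical helper for illustration only: "cuda:0" becomes "cuda",
    and a string without an index (e.g. "cpu") is returned unchanged.
    """
    # partition() splits on the first ":"; everything before it is the type.
    device_type, _sep, _index = device.partition(":")
    return device_type


if __name__ == "__main__":
    print(strip_device_index("cuda:0"))
    print(strip_device_index("cpu"))
```

The diff in this commit does not parse strings this way; it asks the accelerator for `device_name()` instead of `current_device_name()`.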


![image](https://github.com/user-attachments/assets/1ddb61bf-8a15-4e0a-9115-a3681d7f19ff)

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <mtanaka@microsoft.com>
Lzhang-hub 2024-10-24 04:29:30 +08:00 committed by GitHub
Parent e06bb518aa
Commit 6e6563d3c8
No key found matching this signature
GPG key ID: B5690EEEBB952194
1 changed file: 8 additions and 3 deletions

@@ -390,9 +390,14 @@ class TorchBackend(Backend):
         if not required_torch_version(min_version=2.2):
             raise RuntimeError(f"Current torch version does not have device mesh"
                                f"api (torch.__version__: {torch.__version__})")
-        return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
-                                                              mesh_shape,
-                                                              mesh_dim_names=mesh_dim_names)
+        if not required_torch_version(max_version=2.4):
+            return torch.distributed.device_mesh.init_device_mesh(get_accelerator().device_name(),
+                                                                  mesh_shape,
+                                                                  mesh_dim_names=mesh_dim_names)
+        else:
+            return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
+                                                                  mesh_shape,
+                                                                  mesh_dim_names=mesh_dim_names)
 # This will become a light-weight wrapper around torch.distributed functions
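
The version-dependent choice made in the diff above can be sketched in isolation, without torch installed. This is a simplified stand-in that follows the commit message's wording ("starting with torch 2.4"); it is not DeepSpeed's `required_torch_version`, and the names `parse_major_minor` and `device_arg_for_mesh` are hypothetical.

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as "2.4.0+cu121"."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def device_arg_for_mesh(torch_version: str, device_type: str, indexed_device: str) -> str:
    """Pick the device string to pass as the mesh's device type.

    Assumption for this sketch: from 2.4 onward only the bare device type
    (e.g. "cuda") is accepted, while older versions take the indexed form.
    """
    if parse_major_minor(torch_version) >= (2, 4):
        return device_type
    return indexed_device


if __name__ == "__main__":
    print(device_arg_for_mesh("2.4.0+cu121", "cuda", "cuda:0"))
    print(device_arg_for_mesh("2.3.1", "cuda", "cuda:0"))
```

Comparing `(major, minor)` tuples rather than floats avoids misordering versions such as 2.10 and 2.4.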