Allow accelerator to instantiate the device (#5255)

when instantiating torch.device for HPU it cannot be fed with HPU:1
annotation, but only "HPU".
moving the logic to accelerator will allow to solve this issue, with
single line change.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
This commit is contained in:
Nadav Elyahu 2024-08-15 18:01:27 +03:00 коммит произвёл GitHub
Родитель 4ba49ddad8
Коммит eb07d41f95
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
4 изменённых файлов: 4 добавлений и 7 удалений

1
.github/workflows/hpu-gaudi2.yml поставляемый
Просмотреть файл

@ -68,7 +68,6 @@ jobs:
(test_flops_profiler.py and test_flops_profiler_in_inference)
test_get_optim_files.py
test_groups.py
test_init_on_device.py
test_partition_balanced.py
(test_adamw.py and TestAdamConfigs)
test_coalesced_collectives.py

Просмотреть файл

@ -42,9 +42,8 @@ class HPU_Accelerator(DeepSpeedAccelerator):
return True
def device_name(self, device_index=None):
if device_index is None:
return 'hpu'
return 'hpu:{}'.format(device_index)
# ignoring device_index.
return 'hpu'
def device(self, device_index=None):
return torch.device(self.device_name(device_index))

Просмотреть файл

@ -1009,13 +1009,13 @@ class DeepSpeedEngine(Module):
device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank
if device_rank >= 0:
get_accelerator().set_device(device_rank)
self.device = torch.device(get_accelerator().device_name(), device_rank)
self.device = torch.device(get_accelerator().device_name(device_rank))
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
else:
self.world_size = 1
self.global_rank = 0
self.device = torch.device(get_accelerator().device_name())
self.device = get_accelerator().device()
# Configure based on command line arguments
def _configure_with_arguments(self, args, mpu):

Просмотреть файл

@ -68,7 +68,6 @@ def get_lst_from_rank0(lst: List[int]) -> None:
lst_tensor = torch.tensor(
lst if dist.get_rank() == 0 else [-1] * len(lst),
dtype=int,
# device=get_accelerator().current_device_name(),
device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])),
requires_grad=False,
)