Remove synchronize calls from allgather params (#5516)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
Liran Bachar 2024-05-21 18:01:20 +03:00 коммит произвёл GitHub
Родитель 695d79ea06
Коммит 0a1740386f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 6 добавлений и 4 удалений

Просмотреть файл

@ -56,7 +56,7 @@ class NoGatherHandle:
self.__param = param
def wait(self) -> None:
if not get_accelerator().is_synchronized_device():
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
self.__param.ds_status = ZeroParamStatus.AVAILABLE
@ -82,7 +82,7 @@ class NoGatherCoalescedHandle:
if self.__complete:
return
if not get_accelerator().is_synchronized_device():
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()
for param in self.__params:
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
@ -1737,7 +1737,8 @@ class Init(InsertPostInitMethodToModuleSubClasses):
f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
force=False)
get_accelerator().synchronize()
if not get_accelerator().resolves_data_dependency():
get_accelerator().synchronize()
print_rank_0(
f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
@ -1870,7 +1871,8 @@ class Init(InsertPostInitMethodToModuleSubClasses):
param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
# guarantee the communication to be completed
get_accelerator().synchronize()
if not get_accelerator().resolves_data_dependency():
get_accelerator().synchronize()
return None