зеркало из https://github.com/microsoft/DeepSpeed.git
Remove synchronize calls from allgather params (#5516)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
Родитель
695d79ea06
Коммит
0a1740386f
|
@@ -56,7 +56,7 @@ class NoGatherHandle:
|
|||
self.__param = param
|
||||
|
||||
def wait(self) -> None:
|
||||
if not get_accelerator().is_synchronized_device():
|
||||
if not get_accelerator().resolves_data_dependency():
|
||||
get_accelerator().current_stream().synchronize()
|
||||
self.__param.ds_status = ZeroParamStatus.AVAILABLE
|
||||
|
||||
|
@@ -82,7 +82,7 @@ class NoGatherCoalescedHandle:
|
|||
if self.__complete:
|
||||
return
|
||||
|
||||
if not get_accelerator().is_synchronized_device():
|
||||
if not get_accelerator().resolves_data_dependency():
|
||||
get_accelerator().current_stream().synchronize()
|
||||
for param in self.__params:
|
||||
assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
|
||||
|
@@ -1737,7 +1737,8 @@ class Init(InsertPostInitMethodToModuleSubClasses):
|
|||
f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
|
||||
force=False)
|
||||
|
||||
get_accelerator().synchronize()
|
||||
if not get_accelerator().resolves_data_dependency():
|
||||
get_accelerator().synchronize()
|
||||
|
||||
print_rank_0(
|
||||
f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
|
||||
|
@@ -1870,7 +1871,8 @@ class Init(InsertPostInitMethodToModuleSubClasses):
|
|||
param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
|
||||
|
||||
# guarantee the communication to be completed
|
||||
get_accelerator().synchronize()
|
||||
if not get_accelerator().resolves_data_dependency():
|
||||
get_accelerator().synchronize()
|
||||
|
||||
return None
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче