pipe engine _aggregate_total_loss: more efficient loss concatenation (#4327)

* _aggregate_total_loss: more efficient loss concatenation

Optimize the _aggregate_total_loss function to remove the round trip
of copying from device to host and back to device.
This reduces the runtime on the host.

* Fix the if/else block in which the optimization should take place

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Nadav Elyahu 2023-10-23 20:53:32 +03:00 committed by GitHub
Parent 0f2338f7b8
Commit a02de228d0
1 changed file with 1 addition and 1 deletion


@@ -549,7 +549,7 @@ class PipelineEngine(DeepSpeedEngine):
                 agg_loss /= self.dp_world_size
 
             assert self.global_rank in self.grid.pp_group
-            losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
+            losses = torch.stack([self.dp_group_loss, agg_loss])
             if self.is_pipe_parallel:
                 dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
             else:
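
For illustration, here is a minimal standalone sketch (not part of the PR) of why the new line avoids the host round trip. The scalar losses and the device variable are stand-ins for self.dp_group_loss, agg_loss, and self.device in the engine:

import torch

# Stand-ins for self.device, self.dp_group_loss, and agg_loss in PipelineEngine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dp_group_loss = torch.tensor(0.25, device=device)
agg_loss = torch.tensor(0.50, device=device)

# Old path: torch.Tensor(...) builds the result on the CPU, so each scalar
# is first copied device -> host (a blocking sync per element), and the
# trailing .to(device) then copies the result host -> device again.
losses_old = torch.Tensor([dp_group_loss, agg_loss]).to(device)

# New path: torch.stack concatenates the tensors directly on their device,
# with no host round trip.
losses_new = torch.stack([dp_group_loss, agg_loss])

assert torch.allclose(losses_old, losses_new)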